21 changed files with 10564 additions and 2954 deletions
@@ -0,0 +1,108 @@

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// This header exposes a context for cuda custom ops.
// Through the context, a custom cuda operator can fetch existing resources,
// such as the cuda stream and cudnn handle, for reuse.

// For concrete usage, please see:
// https://onnxruntime.ai/docs/reference/operators/add-custom-op.html#custom-ops-for-cuda-and-rocm

#pragma once

#define ORT_CUDA_CTX

#include <cuda.h>
#include <cuda_runtime.h>
#ifndef USE_CUDA_MINIMAL
#include <cublas_v2.h>
#include <cudnn.h>
#endif

#include "core/providers/cuda/cuda_resource.h"
#include "core/providers/custom_op_context.h"

namespace Ort {

namespace Custom {

struct CudaContext : public CustomOpContext {
  cudaStream_t cuda_stream = {};
  cudnnHandle_t cudnn_handle = {};
  cublasHandle_t cublas_handle = {};
  OrtAllocator* deferred_cpu_allocator = {};
  // below are cuda ep options
  int16_t device_id = 0;
  int32_t arena_extend_strategy = 0;
  int32_t cudnn_conv_algo_search = 0;
  bool cudnn_conv_use_max_workspace = true;
  bool cudnn_conv1d_pad_to_nc1d = false;
  bool enable_skip_layer_norm_strict_mode = false;
  bool prefer_nhwc = false;
  bool use_tf32 = true;
  bool fuse_conv_bias = true;

  void Init(const OrtKernelContext& kernel_ctx) {
    cuda_stream = FetchResource<cudaStream_t>(kernel_ctx, CudaResource::cuda_stream_t);
    cudnn_handle = FetchResource<cudnnHandle_t>(kernel_ctx, CudaResource::cudnn_handle_t);
    cublas_handle = FetchResource<cublasHandle_t>(kernel_ctx, CudaResource::cublas_handle_t);
    deferred_cpu_allocator = FetchResource<OrtAllocator*>(kernel_ctx, CudaResource::deferred_cpu_allocator_t);

    device_id = FetchResource<int16_t>(kernel_ctx, CudaResource::device_id_t);
    arena_extend_strategy = FetchResource<int32_t>(kernel_ctx, CudaResource::arena_extend_strategy_t);
    cudnn_conv_algo_search = FetchResource<int32_t>(kernel_ctx, CudaResource::cudnn_conv_algo_search_t);
    cudnn_conv_use_max_workspace = FetchResource<bool>(kernel_ctx, CudaResource::cudnn_conv_use_max_workspace_t);

    cudnn_conv1d_pad_to_nc1d = FetchResource<bool>(kernel_ctx, CudaResource::cudnn_conv1d_pad_to_nc1d_t);
    enable_skip_layer_norm_strict_mode = FetchResource<bool>(
        kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t);
    prefer_nhwc = FetchResource<bool>(kernel_ctx, CudaResource::prefer_nhwc_t);
    use_tf32 = FetchResource<bool>(kernel_ctx, CudaResource::use_tf32_t);
    fuse_conv_bias = FetchResource<bool>(kernel_ctx, CudaResource::fuse_conv_bias_t);
  }

  template <typename T>
  T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) {
    if constexpr (sizeof(T) > sizeof(void*)) {
      ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type),
                        OrtErrorCode::ORT_INVALID_ARGUMENT);
    }
    const auto& ort_api = Ort::GetApi();
    void* resource = {};
    OrtStatus* status = ort_api.KernelContext_GetResource(
        &kernel_ctx, ORT_CUDA_RESOURCE_VERSION, resource_type, &resource);
    if (status) {
      ORT_CXX_API_THROW("Failed to fetch cuda ep resource, resource type: " + std::to_string(resource_type),
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    T t = {};
    memcpy(&t, &resource, sizeof(T));
    return t;
  }

  void* AllocDeferredCpuMem(size_t size) const {
    if (0 == size) {
      return {};
    }
    const auto& ort_api = Ort::GetApi();
    void* mem = {};
    auto status = ort_api.AllocatorAlloc(deferred_cpu_allocator, size, &mem);
    if (status) {
      ORT_CXX_API_THROW("failed to allocate deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return mem;
  }

  void FreeDeferredCpuMem(void* mem) const {
    if (mem) {
      const auto& ort_api = Ort::GetApi();
      auto status = ort_api.AllocatorFree(deferred_cpu_allocator, mem);
      if (status) {
        ORT_CXX_API_THROW("failed to free deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
      }
    }
  }
};

}  // namespace Custom
}  // namespace Ort
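The header above is consumed from inside a custom op's Compute callback. Below is a minimal, hedged sketch of that pattern (the `MyCudaKernel` struct and the commented-out kernel launch are hypothetical, not part of this PR); it relies only on `CudaContext::Init` and the public members shown above.

```cpp
// Sketch: reusing the CUDA EP's stream inside a custom op kernel.
#include "core/providers/cuda/cuda_context.h"

struct MyCudaKernel {  // hypothetical custom op kernel
  void Compute(OrtKernelContext* kernel_ctx) {
    Ort::Custom::CudaContext cuda_ctx;
    cuda_ctx.Init(*kernel_ctx);  // fetches the stream, handles, and EP options

    // Launch work on the EP's stream so it is ordered with the rest of the graph.
    cudaStream_t stream = cuda_ctx.cuda_stream;
    // my_kernel<<<grid, block, 0, stream>>>(...);  // hypothetical kernel launch
    (void)stream;
  }
};
```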
@@ -0,0 +1,23 @@

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/resource.h"

#define ORT_CUDA_RESOURCE_VERSION 3

enum CudaResource : int {
  cuda_stream_t = cuda_resource_offset,  // 10000
  cudnn_handle_t,
  cublas_handle_t,
  deferred_cpu_allocator_t,
  // below are cuda ep options
  device_id_t,  // 10004
  arena_extend_strategy_t,
  cudnn_conv_algo_search_t,
  cudnn_conv_use_max_workspace_t,
  cudnn_conv1d_pad_to_nc1d_t,
  enable_skip_layer_norm_strict_mode_t,
  prefer_nhwc_t,
  use_tf32_t,
  fuse_conv_bias_t
};
@@ -0,0 +1,10 @@

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

// CustomOpContext defines an interface allowing a custom op to access ep-specific resources.
struct CustomOpContext {
  CustomOpContext() = default;
  virtual ~CustomOpContext() {};
};
@@ -0,0 +1,14 @@

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

enum ResourceOffset {
  cpu_resource_offset = 0,
  cuda_resource_offset = 10000,
  dml_resource_offset = 20000,
  rocm_resource_offset = 30000,
  // offsets for other ort eps
  custom_ep_resource_offset = 10000000,
  // offsets for customized eps
};
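As a quick sanity check on the numbering scheme shared by the two enums above, the relationship can be expressed as compile-time assertions. This is a sketch, not part of the PR; the include paths are assumptions based on the `#include` in the cuda resource header.

```cpp
#include "core/providers/resource.h"
#include "core/providers/cuda/cuda_resource.h"

// CudaResource values are carved out of the cuda_resource_offset range.
static_assert(static_cast<int>(cuda_stream_t) == static_cast<int>(cuda_resource_offset),
              "CUDA resources start at 10000");
static_assert(static_cast<int>(device_id_t) == static_cast<int>(cuda_resource_offset) + 4,
              "device_id_t is 10004, matching the comment above");
```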
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,535 @@

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <stdint.h>
#include <cmath>
#include <cstring>
#include <limits>

namespace onnxruntime_float16 {

namespace detail {

enum class endian {
#if defined(_WIN32)
  little = 0,
  big = 1,
  native = little,
#elif defined(__GNUC__) || defined(__clang__)
  little = __ORDER_LITTLE_ENDIAN__,
  big = __ORDER_BIG_ENDIAN__,
  native = __BYTE_ORDER__,
#else
#error onnxruntime_float16::detail::endian is not implemented in this environment.
#endif
};

static_assert(
    endian::native == endian::little || endian::native == endian::big,
    "Only little-endian or big-endian native byte orders are supported.");

}  // namespace detail

/// <summary>
/// Shared implementation between public and internal classes. CRTP pattern.
/// </summary>
template <class Derived>
struct Float16Impl {
 protected:
  /// <summary>
  /// Converts from float to uint16_t float16 representation
  /// </summary>
  /// <param name="v"></param>
  /// <returns></returns>
  constexpr static uint16_t ToUint16Impl(float v) noexcept;

  /// <summary>
  /// Converts float16 to float
  /// </summary>
  /// <returns>float representation of float16 value</returns>
  float ToFloatImpl() const noexcept;

  /// <summary>
  /// Creates an instance that represents absolute value.
  /// </summary>
  /// <returns>Absolute value</returns>
  uint16_t AbsImpl() const noexcept {
    return static_cast<uint16_t>(val & ~kSignMask);
  }

  /// <summary>
  /// Creates a new instance with the sign flipped.
  /// </summary>
  /// <returns>Flipped sign instance</returns>
  uint16_t NegateImpl() const noexcept {
    return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
  }

 public:
  // uint16_t special values
  static constexpr uint16_t kSignMask = 0x8000U;
  static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
  static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
  static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
  static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
  static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
  static constexpr uint16_t kMaxValueBits = 0x7BFFU;  // Largest normal number
  static constexpr uint16_t kOneBits = 0x3C00U;
  static constexpr uint16_t kMinusOneBits = 0xBC00U;

  uint16_t val{0};

  Float16Impl() = default;

  /// <summary>
  /// Checks if the value is negative
  /// </summary>
  /// <returns>true if negative</returns>
  bool IsNegative() const noexcept {
    return static_cast<int16_t>(val) < 0;
  }

  /// <summary>
  /// Tests if the value is NaN
  /// </summary>
  /// <returns>true if NaN</returns>
  bool IsNaN() const noexcept {
    return AbsImpl() > kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value is finite
  /// </summary>
  /// <returns>true if finite</returns>
  bool IsFinite() const noexcept {
    return AbsImpl() < kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value represents positive infinity.
  /// </summary>
  /// <returns>true if positive infinity</returns>
  bool IsPositiveInfinity() const noexcept {
    return val == kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value represents negative infinity
  /// </summary>
  /// <returns>true if negative infinity</returns>
  bool IsNegativeInfinity() const noexcept {
    return val == kNegativeInfinityBits;
  }

  /// <summary>
  /// Tests if the value is either positive or negative infinity.
  /// </summary>
  /// <returns>True if absolute value is infinity</returns>
  bool IsInfinity() const noexcept {
    return AbsImpl() == kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value is NaN or zero. Useful for comparisons.
  /// </summary>
  /// <returns>True if NaN or zero.</returns>
  bool IsNaNOrZero() const noexcept {
    auto abs = AbsImpl();
    return (abs == 0 || abs > kPositiveInfinityBits);
  }

  /// <summary>
  /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
  /// </summary>
  /// <returns>True if so</returns>
  bool IsNormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // is finite
           && (abs != 0)                           // is not zero
           && ((abs & kBiasedExponentMask) != 0);  // is not subnormal (has a non-zero exponent)
  }

  /// <summary>
  /// Tests if the value is subnormal (denormal).
  /// </summary>
  /// <returns>True if so</returns>
  bool IsSubnormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // is finite
           && (abs != 0)                           // is not zero
           && ((abs & kBiasedExponentMask) == 0);  // is subnormal (has a zero exponent)
  }

  /// <summary>
  /// Creates an instance that represents absolute value.
  /// </summary>
  /// <returns>Absolute value</returns>
  Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }

  /// <summary>
  /// Creates a new instance with the sign flipped.
  /// </summary>
  /// <returns>Flipped sign instance</returns>
  Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }

  /// <summary>
  /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
  /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
  /// and therefore equivalent, if the resulting value is still zero.
  /// </summary>
  /// <param name="lhs">first value</param>
  /// <param name="rhs">second value</param>
  /// <returns>True if both arguments represent zero</returns>
  static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
    return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
  }

  bool operator==(const Float16Impl& rhs) const noexcept {
    if (IsNaN() || rhs.IsNaN()) {
      // IEEE defines that NaN is not equal to anything, including itself.
      return false;
    }
    return val == rhs.val;
  }

  bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }

  bool operator<(const Float16Impl& rhs) const noexcept {
    if (IsNaN() || rhs.IsNaN()) {
      // IEEE defines that NaN is unordered with respect to everything, including itself.
      return false;
    }

    const bool left_is_negative = IsNegative();
    if (left_is_negative != rhs.IsNegative()) {
      // When the signs of left and right differ, we know that left is less than right if it is
      // the negative value. The exception to this is if both values are zero, in which case IEEE
      // says they should be equal, even if the signs differ.
      return left_is_negative && !AreZero(*this, rhs);
    }
    return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
  }
};

// The following Float16_t conversions are based on the code from
// Eigen library.

// The conversion routines are Copyright (c) Fabian Giesen, 2016.
// The original license follows:
//
// Copyright (c) Fabian Giesen, 2016
// All rights reserved.
// Redistribution and use in source and binary forms, with or without
// modification, are permitted.
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

namespace detail {
union float32_bits {
  unsigned int u;
  float f;
};
}  // namespace detail

template <class Derived>
inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept {
  detail::float32_bits f{};
  f.f = v;

  constexpr detail::float32_bits f32infty = {255 << 23};
  constexpr detail::float32_bits f16max = {(127 + 16) << 23};
  constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
  constexpr unsigned int sign_mask = 0x80000000u;
  uint16_t val = static_cast<uint16_t>(0x0u);

  unsigned int sign = f.u & sign_mask;
  f.u ^= sign;

  // NOTE all the integer compares in this function can be safely
  // compiled into signed compares since all operands are below
  // 0x80000000. Important if you want fast straight SSE2 code
  // (since there's no unsigned PCMPGTD).

  if (f.u >= f16max.u) {                         // result is Inf or NaN (all exponent bits set)
    val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00;  // NaN->qNaN and Inf->Inf
  } else {                                       // (De)normalized number or zero
    if (f.u < (113 << 23)) {                     // resulting FP16 is subnormal or zero
      // use a magic value to align our 10 mantissa bits at the bottom of
      // the float. as long as FP addition is round-to-nearest-even this
      // just works.
      f.f += denorm_magic.f;

      // and one integer subtract of the bias later, we have our final float!
      val = static_cast<uint16_t>(f.u - denorm_magic.u);
    } else {
      unsigned int mant_odd = (f.u >> 13) & 1;  // resulting mantissa is odd

      // update exponent, rounding bias part 1
      // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
      // without arithmetic overflow.
      f.u += 0xc8000fffU;
      // rounding bias part 2
      f.u += mant_odd;
      // take the bits!
      val = static_cast<uint16_t>(f.u >> 13);
    }
  }

  val |= static_cast<uint16_t>(sign >> 16);
  return val;
}

template <class Derived>
inline float Float16Impl<Derived>::ToFloatImpl() const noexcept {
  constexpr detail::float32_bits magic = {113 << 23};
  constexpr unsigned int shifted_exp = 0x7c00 << 13;  // exponent mask after shift
  detail::float32_bits o{};

  o.u = (val & 0x7fff) << 13;            // exponent/mantissa bits
  unsigned int exp = shifted_exp & o.u;  // just the exponent
  o.u += (127 - 15) << 23;               // exponent adjust

  // handle exponent special cases
  if (exp == shifted_exp) {   // Inf/NaN?
    o.u += (128 - 16) << 23;  // extra exp adjust
  } else if (exp == 0) {      // Zero/Denormal?
    o.u += 1 << 23;           // extra exp adjust
    o.f -= magic.f;           // re-normalize
  }

  // Attempt to workaround the Internal Compiler Error on ARM64
  // for bitwise | operator, including std::bitset
#if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC)
  if (IsNegative()) {
    return -o.f;
  }
#else
  // original code:
  o.u |= (val & 0x8000U) << 16U;  // sign bit
#endif
  return o.f;
}

/// Shared implementation between public and internal classes. CRTP pattern.
template <class Derived>
struct BFloat16Impl {
 protected:
  /// <summary>
  /// Converts from float to uint16_t float16 representation
  /// </summary>
  /// <param name="v"></param>
  /// <returns></returns>
  static uint16_t ToUint16Impl(float v) noexcept;

  /// <summary>
  /// Converts bfloat16 to float
  /// </summary>
  /// <returns>float representation of bfloat16 value</returns>
  float ToFloatImpl() const noexcept;

  /// <summary>
  /// Creates an instance that represents absolute value.
  /// </summary>
  /// <returns>Absolute value</returns>
  uint16_t AbsImpl() const noexcept {
    return static_cast<uint16_t>(val & ~kSignMask);
  }

  /// <summary>
  /// Creates a new instance with the sign flipped.
  /// </summary>
  /// <returns>Flipped sign instance</returns>
  uint16_t NegateImpl() const noexcept {
    return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
  }

 public:
  // uint16_t special values
  static constexpr uint16_t kSignMask = 0x8000U;
  static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
  static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
  static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
  static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
  static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
  static constexpr uint16_t kMaxValueBits = 0x7F7FU;
  static constexpr uint16_t kRoundToNearest = 0x7FFFU;
  static constexpr uint16_t kOneBits = 0x3F80U;
  static constexpr uint16_t kMinusOneBits = 0xBF80U;

  uint16_t val{0};

  BFloat16Impl() = default;

  /// <summary>
  /// Checks if the value is negative
  /// </summary>
  /// <returns>true if negative</returns>
  bool IsNegative() const noexcept {
    return static_cast<int16_t>(val) < 0;
  }

  /// <summary>
  /// Tests if the value is NaN
  /// </summary>
  /// <returns>true if NaN</returns>
  bool IsNaN() const noexcept {
    return AbsImpl() > kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value is finite
  /// </summary>
  /// <returns>true if finite</returns>
  bool IsFinite() const noexcept {
    return AbsImpl() < kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value represents positive infinity.
  /// </summary>
  /// <returns>true if positive infinity</returns>
  bool IsPositiveInfinity() const noexcept {
    return val == kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value represents negative infinity
  /// </summary>
  /// <returns>true if negative infinity</returns>
  bool IsNegativeInfinity() const noexcept {
    return val == kNegativeInfinityBits;
  }

  /// <summary>
  /// Tests if the value is either positive or negative infinity.
  /// </summary>
  /// <returns>True if absolute value is infinity</returns>
  bool IsInfinity() const noexcept {
    return AbsImpl() == kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value is NaN or zero. Useful for comparisons.
  /// </summary>
  /// <returns>True if NaN or zero.</returns>
  bool IsNaNOrZero() const noexcept {
    auto abs = AbsImpl();
    return (abs == 0 || abs > kPositiveInfinityBits);
  }

  /// <summary>
  /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
  /// </summary>
  /// <returns>True if so</returns>
  bool IsNormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // is finite
           && (abs != 0)                           // is not zero
           && ((abs & kBiasedExponentMask) != 0);  // is not subnormal (has a non-zero exponent)
  }

  /// <summary>
  /// Tests if the value is subnormal (denormal).
  /// </summary>
  /// <returns>True if so</returns>
  bool IsSubnormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // is finite
           && (abs != 0)                           // is not zero
           && ((abs & kBiasedExponentMask) == 0);  // is subnormal (has a zero exponent)
  }

  /// <summary>
  /// Creates an instance that represents absolute value.
  /// </summary>
  /// <returns>Absolute value</returns>
  Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }

  /// <summary>
  /// Creates a new instance with the sign flipped.
  /// </summary>
  /// <returns>Flipped sign instance</returns>
  Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }

  /// <summary>
  /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
  /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
  /// and therefore equivalent, if the resulting value is still zero.
  /// </summary>
  /// <param name="lhs">first value</param>
  /// <param name="rhs">second value</param>
  /// <returns>True if both arguments represent zero</returns>
  static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
    // IEEE defines that positive and negative zero are equal, this gives us a quick equality check
    // for two values by or'ing the private bits together and stripping the sign. They are both zero,
    // and therefore equivalent, if the resulting value is still zero.
    return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
  }
};

template <class Derived>
inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept {
  uint16_t result;
  if (std::isnan(v)) {
    result = kPositiveQNaNBits;
  } else {
    auto get_msb_half = [](float fl) {
      uint16_t result;
#ifdef __cpp_if_constexpr
      if constexpr (detail::endian::native == detail::endian::little) {
#else
      if (detail::endian::native == detail::endian::little) {
#endif
        std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
      } else {
        std::memcpy(&result, &fl, sizeof(uint16_t));
      }
      return result;
    };

    uint16_t upper_bits = get_msb_half(v);
    union {
      uint32_t U32;
      float F32;
    };
    F32 = v;
    U32 += (upper_bits & 1) + kRoundToNearest;
    result = get_msb_half(F32);
  }
  return result;
}

template <class Derived>
inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept {
  if (IsNaN()) {
    return std::numeric_limits<float>::quiet_NaN();
  }
  float result;
  char* const first = reinterpret_cast<char*>(&result);
  char* const second = first + sizeof(uint16_t);
#ifdef __cpp_if_constexpr
  if constexpr (detail::endian::native == detail::endian::little) {
#else
  if (detail::endian::native == detail::endian::little) {
#endif
    std::memset(first, 0, sizeof(uint16_t));
    std::memcpy(second, &val, sizeof(uint16_t));
  } else {
    std::memcpy(first, &val, sizeof(uint16_t));
    std::memset(second, 0, sizeof(uint16_t));
  }
  return result;
}

}  // namespace onnxruntime_float16
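Float16Impl and BFloat16Impl are CRTP bases; the public onnxruntime types that derive from them live in other headers of this PR (suppressed above). The sketch below defines a hypothetical derived type to show the contract a derived class must satisfy (`FromBits`) and to exercise the conversion routines; the include name is an assumption.

```cpp
#include <cassert>
#include "onnxruntime_float16.h"  // assumed include name for the header above

// Hypothetical CRTP-derived half type; real types such as Ort::Float16_t are defined elsewhere.
struct MyHalf : onnxruntime_float16::Float16Impl<MyHalf> {
  MyHalf() = default;
  static MyHalf FromBits(uint16_t bits) noexcept {  // required by Abs()/Negate()
    MyHalf h;
    h.val = bits;
    return h;
  }
  explicit MyHalf(float f) noexcept { val = ToUint16Impl(f); }
  float ToFloat() const noexcept { return ToFloatImpl(); }
};

int main() {
  MyHalf one(1.0f);
  assert(one.val == MyHalf::kOneBits);                // 1.0f encodes to 0x3C00
  assert(one.Negate().val == MyHalf::kMinusOneBits);  // sign bit flipped
  assert(one.ToFloat() == 1.0f);                      // 1.0 is exactly representable
  return 0;
}
```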
File diff suppressed because it is too large
@ -0,0 +1,731 @@ |
|||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
// Licensed under the MIT License.
|
||||
|
|
||||
|
// This file contains the training c apis.
|
||||
|
|
||||
|
#pragma once |
||||
|
#include <stdbool.h> |
||||
|
#include "onnxruntime_c_api.h" |
||||
|
|
||||
|
/** \page training_c_cpp_api Training C & C++ APIs
|
||||
|
* |
||||
|
* Training C and C++ APIs are an extension of the \ref c_cpp_api "onnxruntime core C and C++ APIs" and should be used in conjunction with them. |
||||
|
* |
||||
|
* In order to train a model with onnxruntime, the following training artifacts must be generated: |
||||
|
* - The training onnx model |
||||
|
* - The checkpoint file |
||||
|
* - The optimizer onnx model |
||||
|
* - The eval onnx model model (optional) |
||||
|
* |
||||
|
* These training artifacts can be generated as part of an offline step using the python [utilities](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md) made available in the `onnxruntime-training` python package.
|
||||
|
* |
||||
|
* After these artifacts have been generated, the C and C++ utilities listed in this documentation can be leveraged to perform training. |
||||
|
* |
||||
|
* If any problem is encountered, please create an [issue](https://github.com/microsoft/onnxruntime/issues/new) with your scenario and requirements, and we will be sure to respond and follow up on the request.
|
||||
|
* |
||||
|
* <h1>Training C API</h1> |
||||
|
* |
||||
|
* ::OrtTrainingApi - Training C API functions. |
||||
|
* |
||||
|
* This C structure contains functions that enable users to perform training with onnxruntime. |
||||
|
* |
||||
|
* _Sample Code_: |
||||
|
* |
||||
|
* ```c |
||||
|
* #include <onnxruntime_training_api.h> |
||||
|
* |
||||
|
* OrtApi* g_ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION); |
||||
|
* OrtTrainingApi* g_ort_training_api = g_ort_api->GetTrainingApi(ORT_API_VERSION); |
||||
|
* |
||||
|
* OrtEnv* env = NULL; |
||||
|
* g_ort_api->CreateEnv(logging_level, logid, &env); |
||||
|
* OrtSessionOptions* session_options = NULL; |
||||
|
* g_ort_api->CreateSessionOptions(&session_options); |
||||
|
* |
||||
|
* OrtCheckpointState* state = NULL; |
||||
|
* g_ort_training_api->LoadCheckpoint(path_to_checkpoint, &state); |
||||
|
* |
||||
|
* OrtTrainingSession* training_session = NULL; |
||||
|
* g_ort_training_api->CreateTrainingSession(env, session_options, training_model_path, |
||||
|
* state, eval_model_path, optimizer_model_path, |
||||
|
* &training_session); |
||||
|
* // Training loop
|
||||
|
* { |
||||
|
* g_ort_training_api->TrainStep(...); |
||||
|
* g_ort_training_api->OptimizerStep(...); |
||||
|
* g_ort_training_api->LazyResetGrad(...); |
||||
|
* } |
||||
|
* |
||||
|
* g_ort_training_api->ExportModelForInferencing(training_session, inference_model_path, ...); |
||||
|
* g_ort_training_api->SaveCheckpoint(state, path_to_checkpoint, false); |
||||
|
* |
||||
|
* g_ort_training_api->ReleaseTrainingSession(training_session); |
||||
|
* g_ort_training_api->ReleaseCheckpointState(state); |
||||
|
* ``` |
||||
|
* |
||||
|
* > **Note** |
||||
|
* > The ::OrtCheckpointState contains the entire training state that the ::OrtTrainingSession uses. As a result, the training session must always have access to the state. That is to say, the ::OrtCheckpointState instance must outlive the lifetime of the ::OrtTrainingSession instance. |
||||
|
* |
||||
|
* <h1>Training C++ API</h1> |
||||
|
* |
||||
|
* @ref TrainingCpp - Training C++ API classes and functions. |
||||
|
* |
||||
|
* These C++ classes and functions enable users to perform training with onnxruntime. |
||||
|
* |
||||
|
* _Sample Code_: |
||||
|
* |
||||
|
* ```cc |
||||
|
* #include <onnxruntime_training_cxx_api.h> |
||||
|
* |
||||
|
* Ort::Env env; |
||||
|
* Ort::SessionOptions session_options; |
||||
|
* |
||||
|
* auto state = Ort::CheckpointState::LoadCheckpoint(path_to_checkpoint); |
||||
|
* auto training_session = Ort::TrainingSession(env, session_options, state, training_model_path, |
||||
|
* eval_model_path, optimizer_model_path); |
||||
|
* |
||||
|
* // Training Loop
|
||||
|
* { |
||||
|
* training_session.TrainStep(...); |
||||
|
* training_session.OptimizerStep(...); |
||||
|
* training_session.LazyResetGrad(...); |
||||
|
* } |
||||
|
* |
||||
|
* training_session->ExportModelForInferencing(inference_model_path, ...); |
||||
|
* Ort::CheckpointState::SaveCheckpoint(state, path_to_checkpoint, false); |
||||
|
* ``` |
||||
|
* > **Note** |
||||
|
* > The ::Ort::CheckpointState contains the entire training state that the ::Ort::TrainingSession uses. As a result, the training session must always have access to the state. That is to say, the ::Ort::CheckpointState instance must outlive the lifetime of the ::Ort::TrainingSession instance. |
||||
|
*/ |
||||
|
|
||||
|
/** @defgroup TrainingC Ort Training C API
|
||||
|
* @{ |
||||
|
*/ |
||||
|
ORT_RUNTIME_CLASS(TrainingSession); // Type that enables performing training for the given user models.
|
||||
|
ORT_RUNTIME_CLASS(CheckpointState); // Type that holds the training states for the training session.
|
||||
|
|
||||
|
/** \brief Type of property to be added to or returned from the ::OrtCheckpointState.
|
||||
|
*/ |
||||
|
typedef enum OrtPropertyType { |
||||
|
OrtIntProperty = 0, |
||||
|
OrtFloatProperty = 1, |
||||
|
OrtStringProperty = 2, |
||||
|
} OrtPropertyType; |
||||
|
|
||||
|
/** \brief The Training C API that holds onnxruntime training function pointers
|
||||
|
* |
||||
|
* All the Training C API functions are defined inside this structure as pointers to functions. |
||||
|
* Call OrtApi::GetTrainingApi to get a pointer to this struct. |
||||
|
* |
||||
|
* \nosubgrouping |
||||
|
*/ |
||||
|
struct OrtTrainingApi { |
||||
|
/// \name Accessing The Training Session State
|
||||
|
/// @{
|
||||
|
|
||||
|
/** \brief Load a checkpoint state from a file on disk into checkpoint_state.
|
||||
|
* |
||||
|
* This function will parse a checkpoint file, pull relevant data and load the training |
||||
|
* state into the checkpoint_state. This checkpoint state can then be used to create the |
||||
|
* training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training |
||||
|
* session will resume training from the given checkpoint state. |
||||
|
* \note Note that the training session created with a checkpoint state uses this state to store the entire |
||||
|
* training state (including model parameters, its gradients, the optimizer states and the properties). |
||||
|
* As a result, it is required that the checkpoint state outlive the lifetime of the training session. |
||||
|
* \note Note that the checkpoint file can be either the complete checkpoint or the nominal checkpoint. |
||||
|
* |
||||
|
* \param[in] checkpoint_path Path to the checkpoint file |
||||
|
* \param[out] checkpoint_state Checkpoint state that contains the states of the training session. |
||||
|
* |
||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, |
||||
|
_Outptr_ OrtCheckpointState** checkpoint_state); |
||||
|
|
||||
|
/** \brief Save the given state to a checkpoint file on disk.
|
||||
|
* |
||||
|
* This function serializes the provided checkpoint state to a file on disk. |
||||
|
* This checkpoint can later be loaded by invoking OrtTrainingApi::LoadCheckpoint to resume |
||||
|
* training from this snapshot of the state. |
||||
|
* |
||||
|
* \param[in] checkpoint_state The checkpoint state to save. |
||||
|
* \param[in] checkpoint_path Path to the checkpoint file. |
||||
|
* \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not. |
||||
|
* |
||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(SaveCheckpoint, _In_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* checkpoint_path, |
||||
|
const bool include_optimizer_state); |
||||
|
|
||||
|
/// @}
|
||||
|
|
||||
|
/// \name Implementing The Training Loop
|
||||
|
/// @{
|
||||
|
/** \brief Create a training session that can be used to begin or resume training.
|
||||
|
* |
||||
|
* This function creates a training session based on the env and session options provided that can |
||||
|
* begin or resume training from a given checkpoint state for the given onnx models. |
||||
|
* The checkpoint state represents the parameters of the training session which will be moved |
||||
|
* to the device specified by the user through the session options (if necessary). |
||||
|
* The training session requires four training artifacts |
||||
|
* - The training onnx model |
||||
|
* - The evaluation onnx model (optional) |
||||
|
* - The optimizer onnx model |
||||
|
* - The checkpoint file |
||||
|
* |
||||
|
* These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md).
|
||||
|
* |
||||
|
* \param[in] env Environment to be used for the training session. |
||||
|
* \param[in] options Session options that the user can customize for this training session. |
||||
|
* \param[in] checkpoint_state Training states that the training session uses as a starting point for training. |
||||
|
* \param[in] train_model_path Model to be used to perform training. |
||||
|
* \param[in] eval_model_path Model to be used to perform evaluation. |
||||
|
* \param[in] optimizer_model_path Model to be used to perform gradient descent. |
||||
|
* \param[out] out Created training session. |
||||
|
* |
||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, |
||||
|
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, |
||||
|
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path, |
||||
|
_Outptr_result_maybenull_ OrtTrainingSession** out); |
||||
|
|
||||
|
/** \brief Create a training session that can be used to begin or resume training.
|
||||
|
* This api provides a way to load all the training artifacts from buffers instead of files. |
||||
|
* |
||||
|
* \param[in] env Environment to be used for the training session. |
||||
|
* \param[in] options Session options that the user can customize for this training session. |
||||
|
* \param[in] checkpoint_state Training states that the training session uses as a starting point for training. |
||||
|
* \param[in] train_model_data Buffer containing the model data to be used to perform training |
||||
|
* \param[in] train_data_length Length of the buffer containing train_model_data |
||||
|
* \param[in] eval_model_data Buffer containing the model data to be used to perform evaluation |
||||
|
* \param[in] eval_data_length Length of the buffer containing eval_model_data |
||||
|
* \param[in] optim_model_data Buffer containing the model data to be used to perform weight update |
||||
|
* \param[in] optim_data_length Length of the buffer containing optim_model_data |
||||
|
* \param[out] out Created training session. |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env, |
||||
|
_In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, |
||||
|
_In_ const void* train_model_data, size_t train_data_length, |
||||
|
_In_ const void* eval_model_data, size_t eval_data_length, |
||||
|
_In_ const void* optim_model_data, size_t optim_data_length, |
||||
|
_Outptr_result_maybenull_ OrtTrainingSession** out); |
||||
|
|
||||
|
/// @}
|
||||
|
|
||||
|
/// \name Model IO Information
|
||||
|
/// @{
|
||||
|
|
||||
|
/** \brief Retrieves the number of user outputs in the training model.
|
||||
|
* |
||||
|
* This function returns the number of outputs of the training model so that the user can |
||||
|
* allocate space for the number of outputs when OrtTrainingApi::TrainStep is invoked. |
||||
|
* |
||||
|
* \param[in] sess The `this` pointer to the training session. |
||||
|
* \param[out] out Number of user outputs in the training model. |
||||
|
* |
||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); |
||||
|
|
||||
|
/** \brief Retrieves the number of user outputs in the eval model.
|
||||
|
* |
||||
|
* This function returns the number of outputs of the eval model so that the user can |
||||
|
* allocate space for the number of outputs when OrtTrainingApi::EvalStep is invoked. |
||||
|
* |
||||
|
* \param[in] sess The `this` pointer to the training session. |
||||
|
* \param[out] out Number of user outputs in the eval model. |
||||
|
* |
||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(TrainingSessionGetEvalModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); |
||||
|
|
||||
|
/** \brief Retrieves the names of user outputs in the training model.
|
||||
|
* |
||||
|
* This function returns the names of outputs of the training model that can be associated with the OrtValue(s) |
||||
|
* returned by the OrtTrainingApi::TrainStep function. |
||||
|
* |
||||
|
* \param[in] sess The `this` pointer to the training session. |
||||
|
* \param[in] index Index of the output name requested. |
||||
|
* \param[in] allocator Allocator to use to allocate the memory for the name. |
||||
|
* \param[out] output Name of the training model output at the given index. |
||||
|
* |
||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output); |
||||
|
|
||||
|
/** \brief Retrieves the names of user outputs in the eval model.
|
||||
|
* |
||||
|
* This function returns the names of outputs of the eval model that can be associated with the OrtValue(s) returned |
||||
|
* by the OrtTrainingApi::EvalStep function. |
||||
|
* |
||||
|
* \param[in] sess The `this` pointer to the training session. |
||||
|
* \param[in] index Index of the output name requested. |
||||
|
* \param[in] allocator Allocator to use to allocate the memory for the name. |
||||
|
* \param[out] output Name of the eval model output at the given index. |
||||
|
* |
||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(TrainingSessionGetEvalModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output); |
||||
|
|
||||
|
/// @}
|
||||
|
|
||||
|
/// \name Implementing The Training Loop
|
||||
|
/// @{
|
||||
|
|
||||
|
/** \brief Reset the gradients of all trainable parameters to zero lazily.
|
||||
|
* |
||||
|
* This function sets the internal state of the training session such that the gradients of the trainable |
||||
|
* parameters in the OrtCheckpointState will be scheduled to be reset just before the new gradients are |
||||
|
* computed on the next invocation of the next OrtTrainingApi::TrainStep. |
||||
|
* |
||||
|
* \param[in] session The `this` pointer to the training session. |
||||
|
* |
||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(LazyResetGrad, _Inout_ OrtTrainingSession* session); |
||||
|
|
||||
|
/** \brief Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs
|
||||
|
* |
||||
|
* This function performs a training step that computes the outputs of the training model and the gradients |
||||
|
* of the trainable parameters for the given inputs. The train step is performed based on the training model |
||||
|
* that was provided to the training session. |
||||
|
* The OrtTrainingApi::TrainStep is equivalent of running forward propagation and backward propagation in a single |
||||
|
* step. |
||||
|
* The gradients computed are stored inside the training session state so they can be later consumed |
||||
|
* by the OrtTrainingApi::OptimizerStep function. |
||||
|
* The gradients can be lazily reset by invoking the OrtTrainingApi::LazyResetGrad function. |
||||
|
* |
||||
|
* \param[in] sess The `this` pointer to the training session. |
||||
|
* \param[in] run_options Run options for this training step. |
||||
|
* \param[in] inputs_len Number of user inputs to the training model. |
||||
|
* \param[in] inputs The user inputs to the training model. |
||||
|
* \param[in] outputs_len Number of user outputs expected from this training step. |
||||
|
* \param[out] outputs User outputs computed by train step. |
||||
|
* |
||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options, |
||||
|
_In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, |
||||
|
_In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); |
||||
|
|
||||
|
/** \brief Computes the outputs for the eval model for the given inputs
|
||||
|
* |
||||
|
* This function performs an eval step that computes the outputs of the eval model for the given inputs. |
||||
|
* The eval step is performed based on the eval model that was provided to the training session. |
||||
|
* |
||||
|
* \param[in] sess The `this` pointer to the training session. |
||||
|
* \param[in] run_options Run options for this eval step. |
||||
|
* \param[in] inputs_len Number of user inputs to the eval model. |
||||
|
* \param[in] inputs The user inputs to the eval model. |
||||
|
* \param[in] outputs_len Number of user outputs expected from this eval step. |
||||
|
* \param[out] outputs User outputs computed by eval step. |
||||
|
* |
||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(EvalStep, _In_ const OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options, |
||||
|
_In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, |
||||
|
_In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); |
||||
|
|
||||
|
/** \brief Sets the learning rate for this training session.
|
||||
|
* |
||||
|
* This function allows users to set the learning rate for the training session. The current |
||||
|
* learning rate is maintained by the training session and can be overwritten by invoking |
||||
|
* this function with the desired learning rate. This function should not be used when a valid |
||||
|
* learning rate scheduler is registered. It should be used either to set the learning rate |
||||
|
* derived from a custom learning rate scheduler or to set a constant learning rate to be used |
||||
|
* throughout the training session. |
||||
|
* \note Please note that this function does not set the initial learning rate that may be needed |
||||
|
* by the predefined learning rate schedulers. To set the initial learning rate for learning |
||||
|
* rate schedulers, please look at the function OrtTrainingApi::RegisterLinearLRScheduler. |
||||
|
* |
||||
|
* \param[in] sess The `this` pointer to the training session. |
||||
|
* \param[in] learning_rate Desired learning rate to be set. |
||||
|
* |
||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(SetLearningRate, _Inout_ OrtTrainingSession* sess, _In_ float learning_rate); |
||||
|
|
||||
|
/** \brief Gets the current learning rate for this training session.
|
||||
|
* |
||||
|
* This function allows users to get the learning rate for the training session. The current |
||||
|
* learning rate is maintained by the training session, and users can query it for the purpose |
||||
|
* of implementing their own learning rate schedulers. |
||||
|
* |
||||
|
* \param[in] sess The `this` pointer to the training session. |
||||
|
* \param[out] learning_rate Learning rate currently in use by the training session. |
||||
|
* |
||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(GetLearningRate, _Inout_ OrtTrainingSession* sess, _Out_ float* learning_rate); |
||||
|
|
||||
|
/** \brief Performs the weight updates for the trainable parameters using the optimizer model.
|
||||
|
* |
||||
|
* This function performs the weight update step that updates the trainable parameters such that they |
||||
|
* take a step in the direction of their gradients (gradient descent). The optimizer step is performed |
||||
|
* based on the optimizer model that was provided to the training session. |
||||
|
* The updated parameters are stored inside the training state so that they can be used by the next |
||||
|
* OrtTrainingApi::TrainStep function call. |
||||
|
* |
||||
|
* \param[in] sess The `this` pointer to the training session. |
||||
|
* \param[in] run_options Run options for this optimizer step. |
||||
|
* |
||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess, |
||||
|
_In_opt_ const OrtRunOptions* run_options); |
||||
|
|
||||
|
/** \brief Registers a linear learning rate scheduler for the training session.
|
||||
|
* |
||||
|
* Register a linear learning rate scheduler that decays the learning rate by linearly updated |
||||
|
* multiplicative factor from the initial learning rate set on the training session to 0. The decay |
||||
|
* is performed after the initial warm up phase where the learning rate is linearly incremented |
||||
|
* from 0 to the initial learning rate provided. |
||||
|
* |
||||
|
* \param[in] sess The `this` pointer to the training session. |
||||
|
* \param[in] warmup_step_count Warmup steps for LR warmup. |
||||
|
* \param[in] total_step_count Total step count. |
||||
|
* \param[in] initial_lr The initial learning rate to be used by the training session. |
||||
|
* |
||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(RegisterLinearLRScheduler, _Inout_ OrtTrainingSession* sess, _In_ const int64_t warmup_step_count, |
||||
|
_In_ const int64_t total_step_count, _In_ const float initial_lr); |
||||
|
|
||||
|
/** \brief Update the learning rate based on the registered learing rate scheduler.
|
||||
|
* |
||||
|
* Takes a scheduler step that updates the learning rate that is being used by the training session. |
||||
|
* This function should typically be called before invoking the optimizer step for each round, |
||||
|
* or as determined necessary to update the learning rate being used by the training session. |
||||
|
* \note Please note that a valid predefined learning rate scheduler must be first registered to invoke this |
||||
|
* function. |
||||
|
* |
||||
|
* \param[in] sess The `this` pointer to the training session. |
||||
|
* |
||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(SchedulerStep, _Inout_ OrtTrainingSession* sess); |
||||
|
|
||||
|
/// @}
|
||||
|
|
||||
|
/// \name Accessing The Training Session State
|
||||
|
/// @{
|
||||
|
/** \brief Retrieves the size of all the parameters.
|
||||
|
* |
||||
|
* Calculates the total number of primitive (datatype of the parameters) elements of all the parameters in the |
||||
|
* training state. |
||||
|
* When trainable_only argument is true, the size is calculated for trainable params only. |
||||
|
* |
||||
|
* \param[in] sess The `this` pointer to the training session. |
||||
|
* \param[out] out Size of all parameter elements. |
||||
|
* \param[in] trainable_only Whether to skip non-trainable parameters |
||||
|
* |
||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(GetParametersSize, _Inout_ OrtTrainingSession* sess, _Out_ size_t* out, bool trainable_only); |
||||
|
|
||||
|
/** \brief Copy all parameters to a contiguous buffer held by the argument parameters_buffer
|
||||
|
* |
||||
|
* The parameters_buffer has to be of the size given by GetParametersSize api call, |
||||
|
* with matching setting for the argument trainable_only. All the target parameters must be of the same |
||||
|
* datatype. The OrtValue must be pre-allocated onto |
||||
|
* the desired device. This is a complementary function to OrtTrainingApi::CopyBufferToParameters. |
||||
|
* Parameter ordering is preserved. |
||||
|
* User is responsible for allocating and freeing the resources used by the parameters_buffer. |
||||
|
* |
||||
|
* \param[in] sess The `this` pointer to the training session. |
||||
|
* \param[in] trainable_only Whether to skip non-trainable parameters |
||||
|
* \param[out] parameters_buffer The pre-allocated OrtValue buffer to copy onto. |
||||
|
* |
||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(CopyParametersToBuffer, _Inout_ OrtTrainingSession* sess, |
||||
|
_Inout_ OrtValue* parameters_buffer, bool trainable_only); |
||||
|
|
||||
|
/** \brief Copy parameter values from the given contiguous buffer held by parameters_buffer to the training state
|
||||
|
* |
||||
|
* The parameters_buffer argument has to be of the size given by OrtTrainingApi::GetParametersSize api call, |
||||
|
* with matching setting for trainable_only argument. All the target parameters must be of the same |
||||
|
* datatype. This is a complementary function to OrtTrainingApi::CopyParametersToBuffer |
||||
|
* and can be used to load updated buffer values onto the training state. |
||||
|
* Parameter ordering is preserved. |
||||
|
* User is responsible for allocating and freeing the resources used by the parameters_buffer. |
||||
|
* In case the training session was created with a nominal checkpoint, invoking this function is required |
||||
|
* to load the updated parameters onto the checkpoint to complete it. |
||||
|
* |
||||
|
* \param[in] sess The `this` pointer to the training session. |
||||
|
* \param[in] trainable_only Whether to skip non-trainable parameters |
||||
|
* \param[out] parameters_buffer The pre-allocated OrtValue buffer to copy from. |
||||
|
* |
||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value |
||||
|
* |
||||
|
*/ |
||||
|
ORT_API2_STATUS(CopyBufferToParameters, _Inout_ OrtTrainingSession* sess, |
||||
|
_Inout_ OrtValue* parameters_buffer, bool trainable_only); |

  /// @}

  /// \name Release Training Resources
  /// @{

  /** \brief Frees up the memory used up by the training session.
   *
   * This function frees up any memory that was allocated in the training session. The training
   * session can no longer be used after this call.
   *
   */
  ORT_CLASS_RELEASE(TrainingSession);

  /** \brief Frees up the memory used up by the checkpoint state.
   *
   * This function frees up any memory that was allocated in the checkpoint state. The checkpoint
   * state can no longer be used after this call.
   * \note Note that the checkpoint state must be released only after the training session has been released.
   *
   */
  ORT_CLASS_RELEASE(CheckpointState);

  /// @}

  /// \name Prepare For Inferencing
  /// @{
  /** \brief Export a model that can be used for inferencing.
   *
   * If the training session was provided with an eval model, the training session can generate
   * an inference model if it knows the inference graph outputs. The input inference graph outputs
   * are used to prune the eval model so that the inference model's outputs align with the provided outputs.
   * The exported model is saved at the path provided and can be used for inferencing with InferenceSession.
   * \note Note that the function re-loads the eval model from the path provided to OrtTrainingApi::CreateTrainingSession
   * and expects that this path still be valid.
   *
   * \param[in] sess The `this` pointer to the training session.
   * \param[in] inference_model_path Path where the inference model should be serialized to.
   * \param[in] graph_outputs_len Size of the graph output names array.
   * \param[in] graph_output_names Names of the outputs that are needed in the inference model.
   *
   * \snippet{doc} snippets.dox OrtStatus Return Value
   *
   */
  ORT_API2_STATUS(ExportModelForInferencing, _Inout_ OrtTrainingSession* sess,
                  _In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len,
                  _In_reads_(graph_outputs_len) const char* const* graph_output_names);
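
  // Illustrative usage sketch: pruning the eval model down to the listed outputs and saving it for
  // inferencing. The output name below is hypothetical; `g_training_api` and `session` are assumed
  // as above, and status handling is elided.
  //
  //   const char* const graph_outputs[] = {"output-0"};
  //   g_training_api->ExportModelForInferencing(session, ORT_TSTR("inference.onnx"), 1, graph_outputs);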

  /// @}

  /// \name Training Utilities
  /// @{
  /** \brief Sets the seed used for random number generation in Onnxruntime.
   *
   * Use this function to generate reproducible results. It should be noted that completely reproducible
   * results are not guaranteed.
   *
   * \param[in] seed The seed to be set.
   *
   * \snippet{doc} snippets.dox OrtStatus Return Value
   *
   */
  ORT_API2_STATUS(SetSeed, _In_ const int64_t seed);

  /// @}

  /// \name Model IO Information
  /// @{
  /** \brief Retrieves the number of user inputs in the training model.
   *
   * This function returns the number of inputs of the training model so that the user can accordingly
   * allocate the OrtValue(s) provided to the OrtTrainingApi::TrainStep function.
   *
   * \param[in] sess The `this` pointer to the training session.
   * \param[out] out Number of user inputs in the training model.
   *
   * \snippet{doc} snippets.dox OrtStatus Return Value
   *
   */
  ORT_API2_STATUS(TrainingSessionGetTrainingModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);

  /** \brief Retrieves the number of user inputs in the eval model.
   *
   * This function returns the number of inputs of the eval model so that the user can accordingly
   * allocate the OrtValue(s) provided to the OrtTrainingApi::EvalStep function.
   *
   * \param[in] sess The `this` pointer to the training session.
   * \param[out] out Number of user inputs in the eval model.
   *
   * \snippet{doc} snippets.dox OrtStatus Return Value
   *
   */
  ORT_API2_STATUS(TrainingSessionGetEvalModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);

  /** \brief Retrieves the name of the user input at given index in the training model.
   *
   * This function returns the names of inputs of the training model that can be associated with the
   * OrtValue(s) provided to the OrtTrainingApi::TrainStep function.
   *
   * \param[in] sess The `this` pointer to the training session.
   * \param[in] index The index of the training model input name requested.
   * \param[in] allocator The allocator to use to allocate the memory for the requested name.
   * \param[out] output Name of the user input for the training model at the given index.
   *
   * \snippet{doc} snippets.dox OrtStatus Return Value
   *
   */
  ORT_API2_STATUS(TrainingSessionGetTrainingModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
                  _In_ OrtAllocator* allocator, _Outptr_ char** output);

  /** \brief Retrieves the name of the user input at given index in the eval model.
   *
   * This function returns the names of inputs of the eval model that can be associated with the OrtValue(s) provided
   * to the OrtTrainingApi::EvalStep function.
   *
   * \param[in] sess The `this` pointer to the training session.
   * \param[in] index The index of the eval model input name requested.
   * \param[in] allocator The allocator to use to allocate the memory for the requested name.
   * \param[out] output Name of the user input for the eval model at the given index.
   *
   * \snippet{doc} snippets.dox OrtStatus Return Value
   *
   */
  ORT_API2_STATUS(TrainingSessionGetEvalModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
                  _In_ OrtAllocator* allocator, _Outptr_ char** output);

  /// @}

  /// \name Accessing The Training Session State
  /// @{

  /** \brief Adds or updates the given property to/in the checkpoint state.
   *
   * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
   * state by the user by calling this function with the corresponding property name and value.
   * The given property name must be unique to be able to successfully add the property.
   *
   * \param[in] checkpoint_state The checkpoint state which should hold the property.
   * \param[in] property_name Name of the property being added or updated.
   * \param[in] property_type Type of the property associated with the given name.
   * \param[in] property_value Property value associated with the given name.
   *
   * \snippet{doc} snippets.dox OrtStatus Return Value
   *
   */
  ORT_API2_STATUS(AddProperty, _Inout_ OrtCheckpointState* checkpoint_state,
                  _In_ const char* property_name, _In_ enum OrtPropertyType property_type,
                  _In_ void* property_value);

  /** \brief Gets the property value associated with the given name from the checkpoint state.
   *
   * Gets the property value from an existing entry in the checkpoint state. The property must
   * exist in the checkpoint state to be able to retrieve it successfully.
   *
   * \param[in] checkpoint_state The checkpoint state that is currently holding the property.
   * \param[in] property_name Name of the property being retrieved.
   * \param[in] allocator Allocator used to allocate the memory for the property_value.
   * \param[out] property_type Type of the property associated with the given name.
   * \param[out] property_value Property value associated with the given name.
   *
   * \snippet{doc} snippets.dox OrtStatus Return Value
   *
   */
  ORT_API2_STATUS(GetProperty, _In_ const OrtCheckpointState* checkpoint_state,
                  _In_ const char* property_name, _Inout_ OrtAllocator* allocator,
                  _Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value);
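
  // Illustrative usage sketch: storing and reading back a runtime property. The property name is
  // hypothetical; `allocator` is any OrtAllocator (for example the default CPU allocator), and
  // status handling is elided.
  //
  //   int64_t epoch = 10;
  //   g_training_api->AddProperty(checkpoint_state, "epoch", OrtIntProperty, &epoch);
  //
  //   enum OrtPropertyType type;
  //   void* value = NULL;
  //   g_training_api->GetProperty(checkpoint_state, "epoch", allocator, &type, &value);
  //   /* value points to an int64_t when type == OrtIntProperty; free it with the allocator */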

  /// @}

  /// \name Accessing The Training Session State
  /// @{

  /** \brief Load a checkpoint state from a buffer into checkpoint_state.
   *
   * This function will parse a checkpoint bytes buffer, pull relevant data and load the training
   * state into the checkpoint_state. This checkpoint state can then be used to create the
   * training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training
   * session will resume training from the given checkpoint state.
   * \note Note that the training session created with a checkpoint state uses this state to store the entire
   * training state (including model parameters, its gradients, the optimizer states and the properties).
   * As a result, it is required that the checkpoint state outlive the lifetime of the training session.
   *
   * \param[in] checkpoint_buffer The buffer holding the checkpoint bytes.
   * \param[in] num_bytes Number of bytes in the checkpoint buffer.
   * \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
   *
   * \snippet{doc} snippets.dox OrtStatus Return Value
   *
   */
  ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
                  _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);

  /** \brief Retrieves the type and shape information of the parameter associated with the given parameter name.
   *
   * This function retrieves the type and shape of the parameter associated with the given parameter name.
   * The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully.
   *
   * \param[in] checkpoint_state The checkpoint state.
   * \param[in] parameter_name Name of the parameter being retrieved.
   * \param[out] parameter_type_and_shape The type and shape of the parameter being retrieved.
   *
   * \snippet{doc} snippets.dox OrtStatus Return Value
   *
   */
  ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
                  _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape);

  /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
   *
   * This function updates a model parameter in the checkpoint state with the given parameter data.
   * The training session must be already created with the checkpoint state that contains the parameter
   * being updated. The given parameter is copied over to the registered device for the training session.
   * The parameter must exist in the checkpoint state to be able to update it successfully.
   *
   * \param[in] checkpoint_state The checkpoint state.
   * \param[in] parameter_name Name of the parameter being updated.
   * \param[in] parameter The parameter data that should replace the existing parameter data.
   *
   * \snippet{doc} snippets.dox OrtStatus Return Value
   *
   */
  ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
                  _In_ const char* parameter_name, _In_ OrtValue* parameter);

  /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
   *
   * This function retrieves the model parameter data from the checkpoint state for the given parameter name.
   * The parameter is copied over and returned as an OrtValue. The training session must be already created
   * with the checkpoint state that contains the parameter being retrieved.
   * The parameter must exist in the checkpoint state to be able to retrieve it successfully.
   *
   * \param[in] checkpoint_state The checkpoint state.
   * \param[in] parameter_name Name of the parameter being retrieved.
   * \param[in] allocator Allocator used to allocate the memory for the parameter.
   * \param[out] parameter The parameter data that is retrieved from the checkpoint state.
   *
   * \snippet{doc} snippets.dox OrtStatus Return Value
   *
   */
  ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
                  _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
                  _Outptr_ OrtValue** parameter);
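
  // Illustrative usage sketch: inspecting and replacing a single parameter in the checkpoint state.
  // The parameter name is hypothetical, `updated` is an OrtValue built elsewhere with the same type
  // and shape, and status handling is elided.
  //
  //   OrtTensorTypeAndShapeInfo* info = NULL;
  //   g_training_api->GetParameterTypeAndShape(checkpoint_state, "fc1.weight", &info);
  //
  //   OrtValue* current = NULL;
  //   g_training_api->GetParameter(checkpoint_state, "fc1.weight", allocator, &current);
  //   g_training_api->UpdateParameter(checkpoint_state, "fc1.weight", updated);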

  /// @}
};

typedef struct OrtTrainingApi OrtTrainingApi;

/// @}
@ -0,0 +1,418 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "onnxruntime_training_c_api.h"
#include <optional>
#include <variant>

namespace Ort::detail {

#define ORT_DECLARE_TRAINING_RELEASE(NAME) \
  void OrtRelease(Ort##NAME* ptr);

// These release methods must be forward declared before including onnxruntime_cxx_api.h
// otherwise class Base won't be aware of them
ORT_DECLARE_TRAINING_RELEASE(CheckpointState);
ORT_DECLARE_TRAINING_RELEASE(TrainingSession);

}  // namespace Ort::detail

#include "onnxruntime_cxx_api.h"

namespace Ort {

/// <summary>
/// This function returns the C training api struct with the pointers to the ort training C functions.
/// If using C++, please use the class instances instead of invoking the C functions directly.
/// </summary>
/// <returns>OrtTrainingApi struct with ort training C function pointers.</returns>
inline const OrtTrainingApi& GetTrainingApi() { return *GetApi().GetTrainingApi(ORT_API_VERSION); }

namespace detail {

#define ORT_DEFINE_TRAINING_RELEASE(NAME) \
  inline void OrtRelease(Ort##NAME* ptr) { GetTrainingApi().Release##NAME(ptr); }

ORT_DEFINE_TRAINING_RELEASE(CheckpointState);
ORT_DEFINE_TRAINING_RELEASE(TrainingSession);

#undef ORT_DECLARE_TRAINING_RELEASE
#undef ORT_DEFINE_TRAINING_RELEASE

}  // namespace detail

using Property = std::variant<int64_t, float, std::string>;

/**
 * \defgroup TrainingCpp Ort Training C++ API
 * @{
 */

/** \brief Holds the state of the training session.
 *
 * This class holds the entire training session state that includes model parameters, their gradients,
 * optimizer parameters, and user properties. The Ort::TrainingSession leverages the Ort::CheckpointState
 * by accessing and updating the contained training state.
 * \note Note that the training session created with a checkpoint state uses this state to store the entire
 * training state (including model parameters, its gradients, the optimizer states and the properties).
 * The Ort::TrainingSession does not hold a copy of the Ort::CheckpointState and as a result, it is required
 * that the checkpoint state outlive the lifetime of the training session.
 * \note Note that the checkpoint state can be either the complete checkpoint state or the nominal checkpoint
 * state depending on the version provided while loading the checkpoint.
 *
 */
class CheckpointState : public detail::Base<OrtCheckpointState> {
 private:
  CheckpointState(OrtCheckpointState* checkpoint_state) { p_ = checkpoint_state; }

 public:
  // Construct the checkpoint state by loading the checkpoint by calling LoadCheckpoint
  CheckpointState() = delete;

  /// \name Accessing The Training Session State
  /// @{

  /** \brief Load a checkpoint state from a file on disk into checkpoint_state.
   *
   * This function will parse a checkpoint file, pull relevant data and load the training
   * state and return an instance of Ort::CheckpointState. This checkpoint state can then be used to create the
   * training session by instantiating Ort::TrainingSession. By doing so, the training session will resume
   * training from the given checkpoint state.
   *
   * \param[in] path_to_checkpoint Path to the checkpoint file
   * \return Ort::CheckpointState object which holds the state of the training session parameters.
   *
   */
  static CheckpointState LoadCheckpoint(const std::basic_string<ORTCHAR_T>& path_to_checkpoint);

  /** \brief Load a checkpoint state from a buffer.
   *
   * This function will parse a checkpoint buffer, pull relevant data and load the training
   * state and return an instance of Ort::CheckpointState. This checkpoint state can then be used to create the
   * training session by instantiating Ort::TrainingSession. By doing so, the training session will resume
   * training from the given checkpoint state.
   *
   * \param[in] buffer Buffer containing the checkpoint data.
   * \return Ort::CheckpointState object which holds the state of the training session parameters.
   *
   */
  static CheckpointState LoadCheckpointFromBuffer(const std::vector<uint8_t>& buffer);

  /** \brief Save the given state to a checkpoint file on disk.
   *
   * This function serializes the provided checkpoint state to a file on disk.
   * This checkpoint can later be loaded by invoking Ort::CheckpointState::LoadCheckpoint to resume
   * training from this snapshot of the state.
   *
   * \param[in] checkpoint_state The checkpoint state to save.
   * \param[in] path_to_checkpoint Path to the checkpoint file.
   * \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not.
   *
   */
  static void SaveCheckpoint(const CheckpointState& checkpoint_state,
                             const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
                             const bool include_optimizer_state = false);

  /** \brief Adds or updates the given property to/in the checkpoint state.
   *
   * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
   * state by the user by calling this function with the corresponding property name and value.
   * The given property name must be unique to be able to successfully add the property.
   *
   * \param[in] property_name Name of the property being added or updated.
   * \param[in] property_value Property value associated with the given name.
   *
   */
  void AddProperty(const std::string& property_name, const Property& property_value);

  /** \brief Gets the property value associated with the given name from the checkpoint state.
   *
   * Gets the property value from an existing entry in the checkpoint state. The property must
   * exist in the checkpoint state to be able to retrieve it successfully.
   *
   * \param[in] property_name Name of the property being retrieved.
   * \return Property value associated with the given property name.
   *
   */
  Property GetProperty(const std::string& property_name);

  /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
   *
   * This function updates a model parameter in the checkpoint state with the given parameter data.
   * The training session must be already created with the checkpoint state that contains the parameter
   * being updated. The given parameter is copied over to the registered device for the training session.
   * The parameter must exist in the checkpoint state to be able to update it successfully.
   *
   * \param[in] parameter_name Name of the parameter being updated.
   * \param[in] parameter The parameter data that should replace the existing parameter data.
   *
   */
  void UpdateParameter(const std::string& parameter_name, const Value& parameter);

  /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
   *
   * This function retrieves the model parameter data from the checkpoint state for the given parameter name.
   * The parameter is copied over to the provided OrtValue. The training session must be already created
   * with the checkpoint state that contains the parameter being retrieved.
   * The parameter must exist in the checkpoint state to be able to retrieve it successfully.
   *
   * \param[in] parameter_name Name of the parameter being retrieved.
   * \return The parameter data that is retrieved from the checkpoint state.
   *
   */
  Value GetParameter(const std::string& parameter_name);

  /// @}
};
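
// Illustrative usage sketch: typical Ort::CheckpointState usage. The file paths and property name
// are hypothetical.
//
//   Ort::CheckpointState state = Ort::CheckpointState::LoadCheckpoint(ORT_TSTR("checkpoint"));
//   state.AddProperty("epoch", static_cast<int64_t>(10));
//   Ort::Property epoch = state.GetProperty("epoch");  // std::variant<int64_t, float, std::string>
//   Ort::CheckpointState::SaveCheckpoint(state, ORT_TSTR("checkpoint_out"),
//                                        /*include_optimizer_state*/ true);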

/** \brief Trainer class that provides training, evaluation, and optimizer methods for training ONNX models.
 *
 * The training session requires four training artifacts:
 * - The training onnx model
 * - The evaluation onnx model (optional)
 * - The optimizer onnx model
 * - The checkpoint file
 *
 * These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md).
 *
 */
class TrainingSession : public detail::Base<OrtTrainingSession> {
 private:
  size_t training_model_output_count_, eval_model_output_count_;

 public:
  /// \name Constructing the Training Session
  /// @{
  /** \brief Create a training session that can be used to begin or resume training.
   *
   * This constructor instantiates the training session based on the env and session options provided that can
   * begin or resume training from a given checkpoint state for the given onnx models.
   * The checkpoint state represents the parameters of the training session which will be moved
   * to the device specified by the user through the session options (if necessary).
   *
   * \param[in] env Env to be used for the training session.
   * \param[in] session_options SessionOptions that the user can customize for this training session.
   * \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
   * \param[in] train_model_path Model to be used to perform training.
   * \param[in] eval_model_path Model to be used to perform evaluation.
   * \param[in] optimizer_model_path Model to be used to perform gradient descent.
   *
   */
  TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state,
                  const std::basic_string<ORTCHAR_T>& train_model_path,
                  const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path = std::nullopt,
                  const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path = std::nullopt);
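
  // Illustrative usage sketch: constructing a training session from artifacts on disk. The artifact
  // file names are hypothetical.
  //
  //   Ort::Env env;
  //   Ort::SessionOptions options;
  //   Ort::CheckpointState state = Ort::CheckpointState::LoadCheckpoint(ORT_TSTR("checkpoint"));
  //   Ort::TrainingSession session(env, options, state,
  //                                ORT_TSTR("training_model.onnx"),
  //                                ORT_TSTR("eval_model.onnx"),
  //                                ORT_TSTR("optimizer_model.onnx"));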

  /** \brief Create a training session that can be used to begin or resume training.
   * This constructor allows the users to load the models from buffers instead of files.
   *
   * \param[in] env Env to be used for the training session.
   * \param[in] session_options SessionOptions that the user can customize for this training session.
   * \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
   * \param[in] train_model_data Buffer containing training model data.
   * \param[in] eval_model_data Buffer containing evaluation model data.
   * \param[in] optim_model_data Buffer containing optimizer model (used for performing weight/parameter update).
   *
   */
  TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state,
                  const std::vector<uint8_t>& train_model_data, const std::vector<uint8_t>& eval_model_data = {},
                  const std::vector<uint8_t>& optim_model_data = {});
  /// @}

  /// \name Implementing The Training Loop
  /// @{
  /** \brief Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs
   *
   * This function performs a training step that computes the outputs of the training model and the gradients
   * of the trainable parameters for the given inputs. The train step is performed based on the training model
   * that was provided to the training session.
   * Ort::TrainingSession::TrainStep is the equivalent of running forward propagation and backward propagation in a single
   * step.
   * The gradients computed are stored inside the training session state so they can be later consumed
   * by the Ort::TrainingSession::OptimizerStep function.
   * The gradients can be lazily reset by invoking the Ort::TrainingSession::LazyResetGrad function.
   *
   * \param[in] input_values The user inputs to the training model.
   * \return A std::vector of Ort::Value objects that represents the output of the forward pass of the training model.
   *
   */
  std::vector<Value> TrainStep(const std::vector<Value>& input_values);

  /** \brief Reset the gradients of all trainable parameters to zero lazily.
   *
   * This function sets the internal state of the training session such that the gradients of the trainable
   * parameters in the OrtCheckpointState will be scheduled to be reset just before the new gradients are
   * computed on the next invocation of Ort::TrainingSession::TrainStep.
   *
   */
  void LazyResetGrad();

  /** \brief Computes the outputs for the eval model for the given inputs
   *
   * This function performs an eval step that computes the outputs of the eval model for the given inputs.
   * The eval step is performed based on the eval model that was provided to the training session.
   *
   * \param[in] input_values The user inputs to the eval model.
   * \return A std::vector of Ort::Value objects that represents the output of the eval pass.
   *
   */
  std::vector<Value> EvalStep(const std::vector<Value>& input_values);

  /** \brief Sets the learning rate for this training session.
   *
   * This function allows users to set the learning rate for the training session. The current
   * learning rate is maintained by the training session and can be overwritten by invoking
   * this function with the desired learning rate. This function should not be used when a valid
   * learning rate scheduler is registered. It should be used either to set the learning rate
   * derived from a custom learning rate scheduler or to set a constant learning rate to be used
   * throughout the training session.
   * \note Please note that this function does not set the initial learning rate that may be needed
   * by the predefined learning rate schedulers. To set the initial learning rate for learning
   * rate schedulers, please look at the function Ort::TrainingSession::RegisterLinearLRScheduler.
   *
   * \param[in] learning_rate Desired learning rate to be set.
   *
   */
  void SetLearningRate(float learning_rate);

  /** \brief Gets the current learning rate for this training session.
   *
   * This function allows users to get the learning rate for the training session. The current
   * learning rate is maintained by the training session, and users can query it for the purpose
   * of implementing their own learning rate schedulers.
   *
   * \return float representing the current learning rate.
   *
   */
  float GetLearningRate() const;

  /** \brief Registers a linear learning rate scheduler for the training session.
   *
   * Register a linear learning rate scheduler that decays the learning rate by a linearly updated
   * multiplicative factor from the initial learning rate set on the training session to 0. The decay
   * is performed after the initial warm up phase where the learning rate is linearly incremented
   * from 0 to the initial learning rate provided.
   *
   * \param[in] warmup_step_count Warmup steps for LR warmup.
   * \param[in] total_step_count Total step count.
   * \param[in] initial_lr The initial learning rate to be used by the training session.
   *
   */
  void RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count,
                                 float initial_lr);

  /** \brief Update the learning rate based on the registered learning rate scheduler.
   *
   * Takes a scheduler step that updates the learning rate that is being used by the training session.
   * This function should typically be called before invoking the optimizer step for each round,
   * or as determined necessary to update the learning rate being used by the training session.
   * \note Please note that a valid predefined learning rate scheduler must be first registered to invoke this
   * function.
   *
   */
  void SchedulerStep();

  /** \brief Performs the weight updates for the trainable parameters using the optimizer model.
   *
   * This function performs the weight update step that updates the trainable parameters such that they
   * take a step in the direction of their gradients (gradient descent). The optimizer step is performed
   * based on the optimizer model that was provided to the training session.
   * The updated parameters are stored inside the training state so that they can be used by the next
   * Ort::TrainingSession::TrainStep function call.
   *
   */
  void OptimizerStep();

  /// @}
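
  // Illustrative usage sketch: one pass over the data with the training-loop methods above, in the
  // order the documentation describes (scheduler step before optimizer step). `batches` is a
  // hypothetical container of pre-built std::vector<Ort::Value> inputs.
  //
  //   for (auto& inputs : batches) {
  //     std::vector<Ort::Value> outputs = session.TrainStep(inputs);  // forward + backward
  //     session.SchedulerStep();   // only if a learning rate scheduler was registered
  //     session.OptimizerStep();   // apply the accumulated gradients
  //     session.LazyResetGrad();   // schedule a gradient reset before the next TrainStep
  //   }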

  /// \name Prepare For Inferencing
  /// @{

  /** \brief Export a model that can be used for inferencing.
   *
   * If the training session was provided with an eval model, the training session can generate
   * an inference model if it knows the inference graph outputs. The input inference graph outputs
   * are used to prune the eval model so that the inference model's outputs align with the provided outputs.
   * The exported model is saved at the path provided and can be used for inferencing with Ort::Session.
   * \note Note that the function re-loads the eval model from the path provided to Ort::TrainingSession
   * and expects that this path still be valid.
   *
   * \param[in] inference_model_path Path where the inference model should be serialized to.
   * \param[in] graph_output_names Names of the outputs that are needed in the inference model.
   *
   */
  void ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
                                 const std::vector<std::string>& graph_output_names);

  /// @}

  /// \name Model IO Information
  /// @{
  /** \brief Retrieves the names of the user inputs for the training and eval models.
   *
   * This function returns the names of inputs of the training or eval model that can be associated
   * with the Ort::Value(s) provided to the Ort::TrainingSession::TrainStep or Ort::TrainingSession::EvalStep
   * function.
   *
   * \param[in] training Whether the training model input names are requested or eval model input names.
   * \return Graph input names for either the training model or the eval model.
   *
   */
  std::vector<std::string> InputNames(const bool training);

  /** \brief Retrieves the names of the user outputs for the training and eval models.
   *
   * This function returns the names of outputs of the training or eval model that can be associated
   * with the Ort::Value(s) returned by the Ort::TrainingSession::TrainStep or Ort::TrainingSession::EvalStep
   * function.
   *
   * \param[in] training Whether the training model output names are requested or eval model output names.
   * \return Graph output names for either the training model or the eval model.
   *
   */
  std::vector<std::string> OutputNames(const bool training);

  /// @}

  /// \name Accessing The Training Session State
  /// @{

  /** \brief Returns a contiguous buffer that holds a copy of all training state parameters
   *
   * \param[in] only_trainable Whether to only copy trainable parameters or to copy all parameters.
   * \return Contiguous buffer to the model parameters.
   *
   */
  Value ToBuffer(const bool only_trainable);

  /** \brief Loads the training session model parameters from a contiguous buffer
   *
   * In case the training session was created with a nominal checkpoint, invoking this function is required
   * to load the updated parameters onto the checkpoint to complete it.
   *
   * \param[in] buffer Contiguous buffer to load the parameters from.
   */
  void FromBuffer(Value& buffer);
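
  // Illustrative usage sketch: moving all trainable parameters out of and back into the session as
  // one contiguous Ort::Value, e.g. to average them across workers. `session` is assumed to exist.
  //
  //   Ort::Value flat = session.ToBuffer(/*only_trainable*/ true);
  //   /* read or overwrite flat.GetTensorMutableData<float>() */
  //   session.FromBuffer(flat);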

  /// @}
};

/// \name Training Utilities
/// @{
/** \brief This function sets the seed for generating random numbers.
 *
 * Use this function to generate reproducible results. It should be noted that completely
 * reproducible results are not guaranteed.
 *
 * \param[in] seed Manual seed to use for random number generation.
 */
void SetSeed(const int64_t seed);
/// @}

/// @}

}  // namespace Ort

#include "onnxruntime_training_cxx_inline.h"
@ -0,0 +1,295 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "onnxruntime_training_c_api.h"
#include "onnxruntime_cxx_api.h"

namespace Ort {

inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options,
                                        CheckpointState& checkpoint_state,
                                        const std::basic_string<ORTCHAR_T>& train_model_path,
                                        const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path,
                                        const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path) {
  ThrowOnError(GetTrainingApi().CreateTrainingSession(
      env, session_options, checkpoint_state,
      train_model_path.c_str(),
      eval_model_path.has_value() ? eval_model_path.value().c_str() : nullptr,
      optimizer_model_path.has_value() ? optimizer_model_path.value().c_str() : nullptr,
      &p_));

  ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_));

  ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_));
}

inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options,
                                        CheckpointState& checkpoint_state,
                                        const std::vector<uint8_t>& train_model_data,
                                        const std::vector<uint8_t>& eval_model_data,
                                        const std::vector<uint8_t>& optim_model_data) {
  ThrowOnError(GetTrainingApi().CreateTrainingSessionFromBuffer(
      env, session_options, checkpoint_state,
      train_model_data.data(), train_model_data.size(),
      eval_model_data.data(), eval_model_data.size(),
      optim_model_data.data(), optim_model_data.size(),
      &p_));

  ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_));

  ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_));
}

inline std::vector<Value> TrainingSession::TrainStep(const std::vector<Value>& input_values) {
  std::vector<Value> output_values;
  output_values.reserve(training_model_output_count_);
  for (size_t i = 0; i < training_model_output_count_; i++) output_values.emplace_back(nullptr);
  auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values.data());
  auto ort_output_values = reinterpret_cast<OrtValue**>(output_values.data());
  RunOptions run_options;
  ThrowOnError(GetTrainingApi().TrainStep(
      p_, run_options, input_values.size(), ort_input_values,
      training_model_output_count_, ort_output_values));

  return output_values;
}

inline void TrainingSession::LazyResetGrad() {
  ThrowOnError(GetTrainingApi().LazyResetGrad(p_));
}

inline std::vector<Value> TrainingSession::EvalStep(const std::vector<Value>& input_values) {
  std::vector<Value> output_values;
  output_values.reserve(eval_model_output_count_);
  for (size_t i = 0; i < eval_model_output_count_; i++) output_values.emplace_back(nullptr);
  auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values.data());
  auto ort_output_values = reinterpret_cast<OrtValue**>(output_values.data());
  RunOptions run_options;
  ThrowOnError(GetTrainingApi().EvalStep(
      p_, run_options, input_values.size(), ort_input_values,
      eval_model_output_count_, ort_output_values));

  return output_values;
}

inline void TrainingSession::SetLearningRate(float learning_rate) {
  ThrowOnError(GetTrainingApi().SetLearningRate(p_, learning_rate));
}

inline float TrainingSession::GetLearningRate() const {
  float learning_rate = 0;
  ThrowOnError(GetTrainingApi().GetLearningRate(p_, &learning_rate));
  return learning_rate;
}

inline void TrainingSession::RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count,
                                                       float initial_lr) {
  ThrowOnError(GetTrainingApi().RegisterLinearLRScheduler(p_, warmup_step_count, total_step_count,
                                                          initial_lr));
}

inline void TrainingSession::SchedulerStep() {
  ThrowOnError(GetTrainingApi().SchedulerStep(p_));
}

inline void TrainingSession::OptimizerStep() {
  RunOptions run_options;
  ThrowOnError(GetTrainingApi().OptimizerStep(p_, run_options));
}

inline std::vector<std::string> TrainingSession::InputNames(const bool training) {
  auto& input_count_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelInputCount
                                        : GetTrainingApi().TrainingSessionGetEvalModelInputCount;
  auto& input_name_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelInputName
                                       : GetTrainingApi().TrainingSessionGetEvalModelInputName;

  size_t input_count = 0;
  ThrowOnError(input_count_function(p_, &input_count));
  std::vector<std::string> input_names(input_count);
  AllocatorWithDefaultOptions allocator;
  for (size_t index = 0; index < input_count; ++index) {
    char* input_name;
    ThrowOnError(input_name_function(p_, index, allocator, &input_name));
    input_names[index] = std::string(input_name);
    allocator.Free(input_name);
  }

  return input_names;
}

inline std::vector<std::string> TrainingSession::OutputNames(const bool training) {
  auto& output_count_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelOutputCount
                                         : GetTrainingApi().TrainingSessionGetEvalModelOutputCount;
  auto& output_name_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelOutputName
                                        : GetTrainingApi().TrainingSessionGetEvalModelOutputName;

  size_t output_count = 0;
  ThrowOnError(output_count_function(p_, &output_count));
  std::vector<std::string> output_names(output_count);
  AllocatorWithDefaultOptions allocator;
  for (size_t index = 0; index < output_count; ++index) {
    char* output_name;
    ThrowOnError(output_name_function(p_, index, allocator, &output_name));
    output_names[index] = std::string(output_name);
    allocator.Free(output_name);
  }

  return output_names;
}

inline Value TrainingSession::ToBuffer(const bool only_trainable) {
  size_t buffer_size = 0U;
  ThrowOnError(GetTrainingApi().GetParametersSize(p_, &buffer_size, only_trainable));

  std::array<int64_t, 1> buffer_shape{static_cast<int64_t>(buffer_size)};

  AllocatorWithDefaultOptions allocator;
  Value buffer = Value::CreateTensor(allocator, buffer_shape.data(), 1U,
                                     ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);

  ThrowOnError(GetTrainingApi().CopyParametersToBuffer(p_, buffer, only_trainable));

  return buffer;
}

inline void TrainingSession::FromBuffer(Value& buffer) {
  if (!buffer.IsTensor()) {
    ThrowStatus(Status("Incorrect buffer received. Expected a tensor buffer.", OrtErrorCode::ORT_INVALID_ARGUMENT));
  }

  auto tensor_info = buffer.GetTensorTypeAndShapeInfo();
  auto buffer_shape = tensor_info.GetShape();

  if (buffer_shape.size() != 1U) {
    ThrowStatus(Status("Incorrect buffer received. Expected a contiguous tensor buffer.",
                       OrtErrorCode::ORT_INVALID_ARGUMENT));
  }

  auto buffer_size = buffer_shape.front();

  size_t session_buffer_size = 0U;
  ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size, false));

  if (buffer_size == static_cast<int64_t>(session_buffer_size)) {
    ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, false));
    return;
  }

  size_t session_buffer_size_trainable_only = 0U;
  ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size_trainable_only, true));

  if (buffer_size == static_cast<int64_t>(session_buffer_size_trainable_only)) {
    ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, true));
    return;
  } else {
    ThrowStatus(Status("Incorrect buffer size received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
  }
}

inline CheckpointState CheckpointState::LoadCheckpoint(const std::basic_string<ORTCHAR_T>& path_to_checkpoint) {
  OrtCheckpointState* checkpoint_state;
  ThrowOnError(GetTrainingApi().LoadCheckpoint(path_to_checkpoint.c_str(), &checkpoint_state));
  return CheckpointState(checkpoint_state);
}

inline CheckpointState CheckpointState::LoadCheckpointFromBuffer(const std::vector<uint8_t>& buffer) {
  OrtCheckpointState* checkpoint_state;
  ThrowOnError(GetTrainingApi().LoadCheckpointFromBuffer(buffer.data(), buffer.size(), &checkpoint_state));
  return CheckpointState(checkpoint_state);
}

inline void CheckpointState::SaveCheckpoint(const CheckpointState& checkpoint_states,
                                            const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
                                            const bool include_optimizer_state) {
  ThrowOnError(GetTrainingApi().SaveCheckpoint(checkpoint_states, path_to_checkpoint.c_str(),
                                               include_optimizer_state));
}

inline void TrainingSession::ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
                                                       const std::vector<std::string>& graph_output_names) {
  std::vector<const char*> output_names;
  output_names.reserve(graph_output_names.size());
  for (const auto& output_name : graph_output_names) {
    output_names.push_back(output_name.c_str());
  }
  ThrowOnError(GetTrainingApi().ExportModelForInferencing(
      p_, inference_model_path.c_str(), graph_output_names.size(), output_names.data()));
}

inline void SetSeed(const int64_t seed) {
  ThrowOnError(GetTrainingApi().SetSeed(seed));
}

inline void CheckpointState::AddProperty(const std::string& property_name, const Property& property_value) {
  if (std::holds_alternative<int64_t>(property_value)) {
    int64_t value = std::get<int64_t>(property_value);
    void* value_p = &value;
    ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtIntProperty, value_p));
  } else if (std::holds_alternative<float>(property_value)) {
    float value = std::get<float>(property_value);
    void* value_p = &value;
    ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtFloatProperty, value_p));
  } else if (std::holds_alternative<std::string>(property_value)) {
    std::string value = std::get<std::string>(property_value);
    auto buffer = std::make_unique<char[]>(value.length() + 1);
    memcpy(buffer.get(), value.c_str(), value.length());
    // AddProperty takes a char* and calls PropertyBag::AddProperty which takes a std::string. The data will be
    // copied at that point so buffer can free the local allocation once the call is made.
    ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtStringProperty,
                                              buffer.get()));
  } else {
    ThrowStatus(Status("Unknown property type received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
  }
}

inline Property CheckpointState::GetProperty(const std::string& property_name) {
  void* property_value = nullptr;
  OrtPropertyType property_type;

  AllocatorWithDefaultOptions allocator;
  ThrowOnError(GetTrainingApi().GetProperty(p_, property_name.c_str(), allocator, &property_type, &property_value));

  Property property;

  switch (property_type) {
    case OrtPropertyType::OrtIntProperty: {
      auto value_p = reinterpret_cast<int64_t*>(property_value);
      property = *value_p;
      allocator.Free(property_value);
      break;
    }
    case OrtPropertyType::OrtFloatProperty: {
      auto value_p = reinterpret_cast<float*>(property_value);
      property = *value_p;
      allocator.Free(property_value);
      break;
    }
    case OrtPropertyType::OrtStringProperty: {
      auto value_p = reinterpret_cast<char*>(property_value);
      property = std::string(value_p);
      allocator.Free(property_value);
      break;
    }
    default: {
      ThrowStatus(Status("Unknown property type received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
      break;
    }
  }

  return property;
}

inline void CheckpointState::UpdateParameter(const std::string& parameter_name, const Value& parameter) {
  ThrowOnError(GetTrainingApi().UpdateParameter(p_, parameter_name.c_str(), parameter));
}

inline Value CheckpointState::GetParameter(const std::string& parameter_name) {
  AllocatorWithDefaultOptions allocator;
  OrtValue* parameter;
  ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, &parameter));

  return Value{parameter};
}

}  // namespace Ort
@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "onnxruntime_c_api.h"

#ifdef __cplusplus
extern "C" {
#endif

ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id);

#ifdef __cplusplus
}
#endif
Binary file not shown.