diff --git a/.gitignore b/.gitignore index 4f4db76..caf2099 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Ignoring onnxruntime libs +/libs/onnxruntime/lib/* + example-*/config.make example-*/*.sln example-*/*.vcxproj diff --git a/README.md b/README.md index 429277f..4a83769 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,7 @@ # ofxOnnxRuntime + +**Updated version, working with Windows 11, CUDA, and ONNXRuntime 1.20.1** + [ONNX Runtime](https://github.com/microsoft/onnxruntime) tiny wrapper for openFrameworks !['test'](screenshot.png) @@ -17,7 +20,7 @@ - From `Browse` tab, search `Microsoft.ML.OnnxRuntime` (CPU) or `Microsoft.ML.OnnxRuntime.Gpu` (GPU) and install it. 2. DLL direct download - You can download prebuilt DLLs from [here](https://github.com/microsoft/onnxruntime/releases). - - Unzip downloaded `onnxruntime-win-x64-(gpu-)1.10.0.zip` and locate files on `libs\onnxruntime\lib\vs\x64\` . + - Unzip downloaded `onnxruntime-win-x64-(gpu-)1.20.1.zip` and locate files on `libs\onnxruntime\lib\vs\x64\` . - Generate a project using ProjectGenerator, then all libs are linked correctly and all dlls are copied to `bin`. ## Tested environment diff --git a/addon_config.mk b/addon_config.mk index 526a0ef..704f4d8 100644 --- a/addon_config.mk +++ b/addon_config.mk @@ -11,4 +11,6 @@ common: osx: ADDON_LDFLAGS = -Xlinker -rpath -Xlinker @executable_path vs: + ADDON_INCLUDES = libs/onnxruntime/include + ADDON_INCLUDES += src diff --git a/libs/onnxruntime/include/core/providers/cuda/cuda_context.h b/libs/onnxruntime/include/core/providers/cuda/cuda_context.h new file mode 100644 index 0000000..12a1975 --- /dev/null +++ b/libs/onnxruntime/include/core/providers/cuda/cuda_context.h @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This header is to expose a context for cuda custom ops. 
+// By the context, a custom cuda operator could fetch existing resources, +// such as cuda stream and cudnn handle, for reusing. + +// For concrete usage, pls find page here: +// https://onnxruntime.ai/docs/reference/operators/add-custom-op.html#custom-ops-for-cuda-and-rocm + +#pragma once + +#define ORT_CUDA_CTX + +#include +#include +#ifndef USE_CUDA_MINIMAL +#include +#include +#endif + +#include "core/providers/cuda/cuda_resource.h" +#include "core/providers/custom_op_context.h" + +namespace Ort { + +namespace Custom { + +struct CudaContext : public CustomOpContext { + cudaStream_t cuda_stream = {}; + cudnnHandle_t cudnn_handle = {}; + cublasHandle_t cublas_handle = {}; + OrtAllocator* deferred_cpu_allocator = {}; + // below are cuda ep options + int16_t device_id = 0; + int32_t arena_extend_strategy = 0; + int32_t cudnn_conv_algo_search = 0; + bool cudnn_conv_use_max_workspace = true; + bool cudnn_conv1d_pad_to_nc1d = false; + bool enable_skip_layer_norm_strict_mode = false; + bool prefer_nhwc = false; + bool use_tf32 = true; + bool fuse_conv_bias = true; + + void Init(const OrtKernelContext& kernel_ctx) { + cuda_stream = FetchResource(kernel_ctx, CudaResource::cuda_stream_t); + cudnn_handle = FetchResource(kernel_ctx, CudaResource::cudnn_handle_t); + cublas_handle = FetchResource(kernel_ctx, CudaResource::cublas_handle_t); + deferred_cpu_allocator = FetchResource(kernel_ctx, CudaResource::deferred_cpu_allocator_t); + + device_id = FetchResource(kernel_ctx, CudaResource::device_id_t); + arena_extend_strategy = FetchResource(kernel_ctx, CudaResource::arena_extend_strategy_t); + cudnn_conv_algo_search = FetchResource(kernel_ctx, CudaResource::cudnn_conv_algo_search_t); + cudnn_conv_use_max_workspace = FetchResource(kernel_ctx, CudaResource::cudnn_conv_use_max_workspace_t); + + cudnn_conv1d_pad_to_nc1d = FetchResource(kernel_ctx, CudaResource::cudnn_conv1d_pad_to_nc1d_t); + enable_skip_layer_norm_strict_mode = FetchResource( + kernel_ctx, 
CudaResource::enable_skip_layer_norm_strict_mode_t); + prefer_nhwc = FetchResource(kernel_ctx, CudaResource::prefer_nhwc_t); + use_tf32 = FetchResource(kernel_ctx, CudaResource::use_tf32_t); + fuse_conv_bias = FetchResource(kernel_ctx, CudaResource::fuse_conv_bias_t); + } + + template + T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) { + if constexpr (sizeof(T) > sizeof(void*)) { + ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type), + OrtErrorCode::ORT_INVALID_ARGUMENT); + } + const auto& ort_api = Ort::GetApi(); + void* resource = {}; + OrtStatus* status = ort_api.KernelContext_GetResource( + &kernel_ctx, ORT_CUDA_RESOURCE_VERSION, resource_type, &resource); + if (status) { + ORT_CXX_API_THROW("Failed to fetch cuda ep resource, resource type: " + std::to_string(resource_type), + OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + T t = {}; + memcpy(&t, &resource, sizeof(T)); + return t; + } + + void* AllocDeferredCpuMem(size_t size) const { + if (0 == size) { + return {}; + } + const auto& ort_api = Ort::GetApi(); + void* mem = {}; + auto status = ort_api.AllocatorAlloc(deferred_cpu_allocator, size, &mem); + if (status) { + ORT_CXX_API_THROW("failed to allocate deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return mem; + } + + void FreeDeferredCpuMem(void* mem) const { + if (mem) { + const auto& ort_api = Ort::GetApi(); + auto status = ort_api.AllocatorFree(deferred_cpu_allocator, mem); + if (status) { + ORT_CXX_API_THROW("failed to free deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + } + } +}; + +} // namespace Custom +} // namespace Ort diff --git a/libs/onnxruntime/include/core/providers/cuda/cuda_resource.h b/libs/onnxruntime/include/core/providers/cuda/cuda_resource.h new file mode 100644 index 0000000..b248d33 --- /dev/null +++ b/libs/onnxruntime/include/core/providers/cuda/cuda_resource.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/resource.h" + +#define ORT_CUDA_RESOURCE_VERSION 3 + +enum CudaResource : int { + cuda_stream_t = cuda_resource_offset, // 10000 + cudnn_handle_t, + cublas_handle_t, + deferred_cpu_allocator_t, + // below are cuda ep options + device_id_t, // 10004 + arena_extend_strategy_t, + cudnn_conv_algo_search_t, + cudnn_conv_use_max_workspace_t, + cudnn_conv1d_pad_to_nc1d_t, + enable_skip_layer_norm_strict_mode_t, + prefer_nhwc_t, + use_tf32_t, + fuse_conv_bias_t +}; diff --git a/libs/onnxruntime/include/core/providers/custom_op_context.h b/libs/onnxruntime/include/core/providers/custom_op_context.h new file mode 100644 index 0000000..b10126d --- /dev/null +++ b/libs/onnxruntime/include/core/providers/custom_op_context.h @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +// CustomOpContext defines an interface allowing a custom op to access ep-specific resources. +struct CustomOpContext { + CustomOpContext() = default; + virtual ~CustomOpContext() {}; +}; \ No newline at end of file diff --git a/libs/onnxruntime/include/core/providers/resource.h b/libs/onnxruntime/include/core/providers/resource.h new file mode 100644 index 0000000..bd123e1 --- /dev/null +++ b/libs/onnxruntime/include/core/providers/resource.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +enum ResourceOffset { + cpu_resource_offset = 0, + cuda_resource_offset = 10000, + dml_resource_offset = 20000, + rocm_resource_offset = 30000, + // offsets for other ort eps + custom_ep_resource_offset = 10000000, + // offsets for customized eps +}; \ No newline at end of file diff --git a/libs/onnxruntime/include/onnxruntime_c_api.h b/libs/onnxruntime/include/onnxruntime_c_api.h index b949be4..fcf3239 100644 --- a/libs/onnxruntime/include/onnxruntime_c_api.h +++ b/libs/onnxruntime/include/onnxruntime_c_api.h @@ -3,34 +3,42 @@ // See docs\c_cxx\README.md on generating the Doxygen documentation from this file -/** \mainpage C & C++ APIs -* -*

<h1>C</h1>

-* -* ::OrtApi - Click here to jump to the structure with all C API functions. -* -*

<h1>C++</h1>

-* -* ::Ort - Click here to jump to the namespace holding all of the C++ wrapper classes -* -* It is a set of header only wrapper classes around the C API. The goal is to turn the C style return value error codes into C++ exceptions, and to -* automate memory management through standard C++ RAII principles. -* -* \addtogroup Global -* ONNX Runtime C API -* @{ -*/ +/** \mainpage ONNX Runtime + * + * ONNX Runtime is a high-performance inference and training graph execution engine for deep learning models. + * + * ONNX Runtime's C, C++ APIs offer an easy to use interface to onboard and execute onnx models. + * - \subpage c_cpp_api "Core C, C++ APIs" + * - \subpage training_c_cpp_api "Training C, C++ APIs for on-device training" + * + * \page c_cpp_api Core C, C++ APIs + *

<h1>C</h1>

+ * + * ::OrtApi - Click here to go to the structure with all C API functions. + * + *

<h1>C++</h1>

+ * + * ::Ort - Click here to go to the namespace holding all of the C++ wrapper classes + * + * It is a set of header only wrapper classes around the C API. The goal is to turn the C style return value error codes into C++ exceptions, and to + * automate memory management through standard C++ RAII principles. + * + * \addtogroup Global + * ONNX Runtime C API + * @{ + */ #pragma once -#include +#include #include +#include #include /** \brief The API version defined in this header -* -* This value is used by some API functions to behave as this version of the header expects. -*/ -#define ORT_API_VERSION 10 + * + * This value is used by some API functions to behave as this version of the header expects. + */ +#define ORT_API_VERSION 20 #ifdef __cplusplus extern "C" { @@ -54,6 +62,8 @@ extern "C" { #define _Check_return_ #define _Outptr_result_maybenull_ #define _In_reads_(X) +#define _Inout_updates_(X) +#define _Out_writes_(X) #define _Inout_updates_all_(X) #define _Out_writes_bytes_all_(X) #define _Out_writes_all_(X) @@ -88,12 +98,24 @@ extern "C" { #define ORTCHAR_T char #endif +/// ORTCHAR_T, ORT_TSTR are reserved specifically for path handling. +/// All other strings are UTF-8 encoded, use char and std::string #ifndef ORT_TSTR #ifdef _WIN32 #define ORT_TSTR(X) L##X +// When X is a macro, L##X is not defined. In this case, we need to use ORT_TSTR_ON_MACRO. +#define ORT_TSTR_ON_MACRO(X) L"" X #else #define ORT_TSTR(X) X +#define ORT_TSTR_ON_MACRO(X) X +#endif #endif + +// On Windows, ORT_FILE is a wchar_t version of the __FILE__ macro. +// Otherwise, ORT_FILE is equivalent to __FILE__. +#ifndef ORT_FILE +#define ORT_FILE_INTERNAL(x) ORT_TSTR(x) +#define ORT_FILE ORT_FILE_INTERNAL(__FILE__) #endif // Any pointer marked with _In_ or _Out_, cannot be NULL. @@ -120,8 +142,9 @@ extern "C" { // __VA_ARGS__ on Windows and Linux are different #define ORT_API(RETURN_TYPE, NAME, ...) RETURN_TYPE ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION -#define ORT_API_STATUS(NAME, ...) 
\ - _Success_(return == 0) _Check_return_ _Ret_maybenull_ OrtStatusPtr ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION ORT_MUST_USE_RESULT +#define ORT_API_STATUS(NAME, ...) \ + _Success_(return == 0) _Check_return_ _Ret_maybenull_ OrtStatusPtr ORT_API_CALL NAME(__VA_ARGS__) \ + NO_EXCEPTION ORT_MUST_USE_RESULT // XXX: Unfortunately, SAL annotations are known to not work with function pointers #define ORT_API2_STATUS(NAME, ...) \ @@ -149,8 +172,8 @@ extern "C" { */ /** Copied from TensorProto::DataType -* Currently, Ort doesn't support complex64, complex128 -*/ + * Currently, Ort doesn't support complex64, complex128 + */ typedef enum ONNXTensorElementDataType { ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, // maps to c type float @@ -168,7 +191,15 @@ typedef enum ONNXTensorElementDataType { ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, // maps to c type uint64_t ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64, // complex with float32 real and imaginary components ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128, // complex with float64 real and imaginary components - ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 // Non-IEEE floating-point format based on IEEE754 single-precision + ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, // Non-IEEE floating-point format based on IEEE754 single-precision + // float 8 types were introduced in onnx 1.14, see https://onnx.ai/onnx/technical/float8.html + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, // Non-IEEE floating-point format based on IEEE754 single-precision + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2, // Non-IEEE floating-point format based on IEEE754 single-precision + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision + // Int4 types were introduced in ONNX 1.16. 
See https://onnx.ai/onnx/technical/int4.html + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, // maps to a pair of packed uint4 values (size == 1 byte) + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 // maps to a pair of packed int4 values (size == 1 byte) } ONNXTensorElementDataType; // Synced with onnx TypeProto oneof @@ -226,10 +257,20 @@ typedef enum OrtErrorCode { ORT_EP_FAIL, } OrtErrorCode; +typedef enum OrtOpAttrType { + ORT_OP_ATTR_UNDEFINED = 0, + ORT_OP_ATTR_INT, + ORT_OP_ATTR_INTS, + ORT_OP_ATTR_FLOAT, + ORT_OP_ATTR_FLOATS, + ORT_OP_ATTR_STRING, + ORT_OP_ATTR_STRINGS, +} OrtOpAttrType; + //! @} #define ORT_RUNTIME_CLASS(X) \ struct Ort##X; \ - typedef struct Ort##X Ort##X; + typedef struct Ort##X Ort##X /** \addtogroup Global * ONNX Runtime C API @@ -240,21 +281,30 @@ ORT_RUNTIME_CLASS(Env); ORT_RUNTIME_CLASS(Status); // nullptr for Status* indicates success ORT_RUNTIME_CLASS(MemoryInfo); ORT_RUNTIME_CLASS(IoBinding); -ORT_RUNTIME_CLASS(Session); //Don't call ReleaseSession from Dllmain (because session owns a thread pool) +ORT_RUNTIME_CLASS(Session); // Don't call ReleaseSession from Dllmain (because session owns a thread pool) ORT_RUNTIME_CLASS(Value); ORT_RUNTIME_CLASS(RunOptions); ORT_RUNTIME_CLASS(TypeInfo); ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo); -ORT_RUNTIME_CLASS(SessionOptions); -ORT_RUNTIME_CLASS(CustomOpDomain); ORT_RUNTIME_CLASS(MapTypeInfo); ORT_RUNTIME_CLASS(SequenceTypeInfo); +ORT_RUNTIME_CLASS(OptionalTypeInfo); +ORT_RUNTIME_CLASS(SessionOptions); +ORT_RUNTIME_CLASS(CustomOpDomain); ORT_RUNTIME_CLASS(ModelMetadata); ORT_RUNTIME_CLASS(ThreadPoolParams); ORT_RUNTIME_CLASS(ThreadingOptions); ORT_RUNTIME_CLASS(ArenaCfg); ORT_RUNTIME_CLASS(PrepackedWeightsContainer); ORT_RUNTIME_CLASS(TensorRTProviderOptionsV2); +ORT_RUNTIME_CLASS(CUDAProviderOptionsV2); +ORT_RUNTIME_CLASS(CANNProviderOptions); +ORT_RUNTIME_CLASS(DnnlProviderOptions); +ORT_RUNTIME_CLASS(Op); +ORT_RUNTIME_CLASS(OpAttr); +ORT_RUNTIME_CLASS(Logger); +ORT_RUNTIME_CLASS(ShapeInferContext); 
+ORT_RUNTIME_CLASS(LoraAdapter); #ifdef _WIN32 typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -263,16 +313,22 @@ typedef OrtStatus* OrtStatusPtr; #endif /** \brief Memory allocation interface -* -* Structure of function pointers that defines a memory allocator. This can be created and filled in by the user for custom allocators. -* -* When an allocator is passed to any function, be sure that the allocator object is not destroyed until the last allocated object using it is freed. -*/ + * + * Structure of function pointers that defines a memory allocator. This can be created and filled in by the user for custom allocators. + * + * When an allocator is passed to any function, be sure that the allocator object is not destroyed until the last allocated object using it is freed. + */ typedef struct OrtAllocator { uint32_t version; ///< Must be initialized to ORT_API_VERSION void*(ORT_API_CALL* Alloc)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes void(ORT_API_CALL* Free)(struct OrtAllocator* this_, void* p); ///< Free a block of memory previously allocated with OrtAllocator::Alloc const struct OrtMemoryInfo*(ORT_API_CALL* Info)(const struct OrtAllocator* this_); ///< Return a pointer to an ::OrtMemoryInfo that describes this allocator + /** + * @brief Optional allocation function to use for memory allocations made during session initialization. + * Use this function if you want to separate allocations made by ORT during Run() calls from + * those made during session initialization. This allows for separate memory management strategies for these allocations. 
+ */ + void*(ORT_API_CALL* Reserve)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes } OrtAllocator; typedef void(ORT_API_CALL* OrtLoggingFunction)( @@ -280,10 +336,10 @@ typedef void(ORT_API_CALL* OrtLoggingFunction)( const char* message); /** \brief Graph optimization level -* -* Refer to https://www.onnxruntime.ai/docs/resources/graph-optimizations.html -* for an in-depth understanding of Graph Optimizations -*/ + * + * Refer to https://www.onnxruntime.ai/docs/performance/graph-optimizations.html#graph-optimization-levels + * for an in-depth understanding of the Graph Optimization Levels. + */ typedef enum GraphOptimizationLevel { ORT_DISABLE_ALL = 0, ORT_ENABLE_BASIC = 1, @@ -297,8 +353,8 @@ typedef enum ExecutionMode { } ExecutionMode; /** \brief Language projection identifiers -* /see OrtApi::SetLanguageProjection -*/ + * /see OrtApi::SetLanguageProjection + */ typedef enum OrtLanguageProjection { ORT_PROJECTION_C = 0, ORT_PROJECTION_CPLUSPLUS = 1, @@ -323,7 +379,7 @@ typedef enum OrtAllocatorType { } OrtAllocatorType; /** \brief Memory types for allocated memory, execution provider specific types should be extended in each provider. 
-*/ + */ // Whenever this struct is updated, please also update the MakeKey function in onnxruntime / core / framework / execution_provider.cc typedef enum OrtMemType { OrtMemTypeCPUInput = -2, ///< Any CPU memory used by non-CPU execution provider @@ -332,8 +388,16 @@ typedef enum OrtMemType { OrtMemTypeDefault = 0, ///< The default allocator for execution provider } OrtMemType; +/** \brief This mimics OrtDevice type constants so they can be returned in the API + */ +typedef enum OrtMemoryInfoDeviceType { + OrtMemoryInfoDeviceType_CPU = 0, + OrtMemoryInfoDeviceType_GPU = 1, + OrtMemoryInfoDeviceType_FPGA = 2 +} OrtMemoryInfoDeviceType; + /** \brief Algorithm to use for cuDNN Convolution Op -*/ + */ typedef enum OrtCudnnConvAlgoSearch { OrtCudnnConvAlgoSearchExhaustive, // expensive exhaustive benchmarking using cudnnFindConvolutionForwardAlgorithmEx OrtCudnnConvAlgoSearchHeuristic, // lightweight heuristic based search using cudnnGetConvolutionForwardAlgorithm_v7 @@ -341,126 +405,185 @@ typedef enum OrtCudnnConvAlgoSearch { } OrtCudnnConvAlgoSearch; /** \brief CUDA Provider Options -* -* \see OrtApi::SessionOptionsAppendExecutionProvider_CUDA -*/ + * + * \see OrtApi::SessionOptionsAppendExecutionProvider_CUDA + */ typedef struct OrtCUDAProviderOptions { #ifdef __cplusplus - OrtCUDAProviderOptions() : device_id{}, cudnn_conv_algo_search{OrtCudnnConvAlgoSearchExhaustive}, gpu_mem_limit{SIZE_MAX}, arena_extend_strategy{}, do_copy_in_default_stream{1}, has_user_compute_stream{}, user_compute_stream{}, default_memory_arena_cfg{} {} + OrtCUDAProviderOptions() + : device_id{}, + cudnn_conv_algo_search{OrtCudnnConvAlgoSearchExhaustive}, + gpu_mem_limit{SIZE_MAX}, + arena_extend_strategy{}, + do_copy_in_default_stream{1}, + has_user_compute_stream{}, + user_compute_stream{}, + default_memory_arena_cfg{}, + tunable_op_enable{false}, + tunable_op_tuning_enable{false}, + tunable_op_max_tuning_duration_ms{} {} #endif /** \brief CUDA device Id - * Defaults to 0. 
- */ + * Defaults to 0. + */ int device_id; /** \brief CUDA Convolution algorithm search configuration. - * See enum OrtCudnnConvAlgoSearch for more details. - * Defaults to OrtCudnnConvAlgoSearchExhaustive. - */ + * See enum OrtCudnnConvAlgoSearch for more details. + * Defaults to OrtCudnnConvAlgoSearchExhaustive. + */ OrtCudnnConvAlgoSearch cudnn_conv_algo_search; /** \brief CUDA memory limit (To use all possible memory pass in maximum size_t) - * Defaults to SIZE_MAX. - * \note If a ::OrtArenaCfg has been applied, it will override this field - */ + * Defaults to SIZE_MAX. + * \note If a ::OrtArenaCfg has been applied, it will override this field + */ size_t gpu_mem_limit; /** \brief Strategy used to grow the memory arena - * 0 = kNextPowerOfTwo
- * 1 = kSameAsRequested<br>
- * Defaults to 0. - * \note If a ::OrtArenaCfg has been applied, it will override this field - */ + * 0 = kNextPowerOfTwo<br>
+ * 1 = kSameAsRequested<br>
+ * Defaults to 0. + * \note If a ::OrtArenaCfg has been applied, it will override this field + */ int arena_extend_strategy; - /** \brief Flag indicating if copying needs to take place on the same stream as the compute stream in the CUDA EP - * 0 = Use separate streams for copying and compute. - * 1 = Use the same stream for copying and compute. - * Defaults to 1. - * WARNING: Setting this to 0 may result in data races for some models. - * Please see issue #4829 for more details. - */ + /** \brief Flag indicating if copying needs to take place on the same stream as the compute stream in the CUDA EP + * 0 = Use separate streams for copying and compute. + * 1 = Use the same stream for copying and compute. + * Defaults to 1. + * WARNING: Setting this to 0 may result in data races for some models. + * Please see issue #4829 for more details. + */ int do_copy_in_default_stream; /** \brief Flag indicating if there is a user provided compute stream - * Defaults to 0. - */ + * Defaults to 0. + */ int has_user_compute_stream; - /** \brief User provided compute stream. - * If provided, please set `has_user_compute_stream` to 1. - */ + /** \brief User provided compute stream. + * If provided, please set `has_user_compute_stream` to 1. + */ void* user_compute_stream; /** \brief CUDA memory arena configuration parameters - */ + */ OrtArenaCfg* default_memory_arena_cfg; + /** \brief Enable TunableOp for using. + * Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default. + * This option can be overridden by environment variable ORT_CUDA_TUNABLE_OP_ENABLE. + */ + int tunable_op_enable; + + /** \brief Enable TunableOp for tuning. + * Set it to 1/0 to enable/disable TunableOp tuning. Otherwise, it is disabled by default. + * This option can be overridden by environment variable ORT_CUDA_TUNABLE_OP_TUNING_ENABLE. + */ + int tunable_op_tuning_enable; + + /** \brief Max tuning duration time limit for each instance of TunableOp. 
+ * Defaults to 0 to disable the limit. + */ + int tunable_op_max_tuning_duration_ms; + } OrtCUDAProviderOptions; /** \brief ROCM Provider Options -* -* \see OrtApi::SessionOptionsAppendExecutionProvider_ROCM -*/ + * + * \see OrtApi::SessionOptionsAppendExecutionProvider_ROCM + */ typedef struct OrtROCMProviderOptions { #ifdef __cplusplus - OrtROCMProviderOptions() : device_id{}, miopen_conv_exhaustive_search{0}, gpu_mem_limit{SIZE_MAX}, arena_extend_strategy{}, do_copy_in_default_stream{1}, has_user_compute_stream{}, user_compute_stream{}, default_memory_arena_cfg{} {} + OrtROCMProviderOptions() + : device_id{}, + miopen_conv_exhaustive_search{0}, + gpu_mem_limit{SIZE_MAX}, + arena_extend_strategy{}, + do_copy_in_default_stream{1}, + has_user_compute_stream{}, + user_compute_stream{}, + default_memory_arena_cfg{}, + enable_hip_graph{false}, + tunable_op_enable{false}, + tunable_op_tuning_enable{false}, + tunable_op_max_tuning_duration_ms{} {} #endif /** \brief ROCM device Id - * Defaults to 0. - */ + * Defaults to 0. + */ int device_id; /** \brief ROCM MIOpen Convolution algorithm exaustive search option. - * Defaults to 0 (false). - */ + * Defaults to 0 (false). + */ int miopen_conv_exhaustive_search; /** \brief ROCM memory limit (To use all possible memory pass in maximum size_t) - * Defaults to SIZE_MAX. - * \note If a ::OrtArenaCfg has been applied, it will override this field - */ + * Defaults to SIZE_MAX. + * \note If a ::OrtArenaCfg has been applied, it will override this field + */ size_t gpu_mem_limit; /** \brief Strategy used to grow the memory arena - * 0 = kNextPowerOfTwo
- * 1 = kSameAsRequested<br>
- * Defaults to 0. - * \note If a ::OrtArenaCfg has been applied, it will override this field - */ + * 0 = kNextPowerOfTwo<br>
+ * 1 = kSameAsRequested<br>
+ * Defaults to 0. + * \note If a ::OrtArenaCfg has been applied, it will override this field + */ int arena_extend_strategy; - /** \brief Flag indicating if copying needs to take place on the same stream as the compute stream in the ROCM EP - * 0 = Use separate streams for copying and compute. - * 1 = Use the same stream for copying and compute. - * Defaults to 1. - * WARNING: Setting this to 0 may result in data races for some models. - * Please see issue #4829 for more details. - */ + /** \brief Flag indicating if copying needs to take place on the same stream as the compute stream in the ROCM EP + * 0 = Use separate streams for copying and compute. + * 1 = Use the same stream for copying and compute. + * Defaults to 1. + * WARNING: Setting this to 0 may result in data races for some models. + * Please see issue #4829 for more details. + */ int do_copy_in_default_stream; /** \brief Flag indicating if there is a user provided compute stream - * Defaults to 0. - */ + * Defaults to 0. + */ int has_user_compute_stream; - /** \brief User provided compute stream. - * If provided, please set `has_user_compute_stream` to 1. - */ + /** \brief User provided compute stream. + * If provided, please set `has_user_compute_stream` to 1. + */ void* user_compute_stream; /** \brief ROCM memory arena configuration parameters - */ + */ OrtArenaCfg* default_memory_arena_cfg; + int enable_hip_graph; + + /** \brief Enable TunableOp for using. + * Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default. + * This option can be overridden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE. + */ + int tunable_op_enable; + + /** \brief Enable TunableOp for tuning. + * Set it to 1/0 to enable/disable TunableOp tuning. Otherwise, it is disabled by default. + * This option can be overridden by environment variable ORT_ROCM_TUNABLE_OP_TUNING_ENABLE. + */ + int tunable_op_tuning_enable; + + /** \brief Max tuning duration time limit for each instance of TunableOp. 
+ * Defaults to 0 to disable the limit. + */ + int tunable_op_max_tuning_duration_ms; + } OrtROCMProviderOptions; /** \brief TensorRT Provider Options -* -* \see OrtApi::SessionOptionsAppendExecutionProvider_TensorRT -*/ + * + * \see OrtApi::SessionOptionsAppendExecutionProvider_TensorRT + */ typedef struct OrtTensorRTProviderOptions { int device_id; ///< CUDA device id (0 = default device) int has_user_compute_stream; // indicator of user specified CUDA compute stream. @@ -480,109 +603,162 @@ typedef struct OrtTensorRTProviderOptions { int trt_engine_decryption_enable; // enable engine decryption. Default 0 = false, nonzero = true const char* trt_engine_decryption_lib_path; // specify engine decryption library path int trt_force_sequential_engine_build; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true + // This is the legacy struct and don't add new fields here. + // For new field that can be represented by string, please add it in include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h + // For non-string field, need to create a new separate api to handle it. } OrtTensorRTProviderOptions; +/** \brief MIGraphX Provider Options + * + * \see OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX + */ +typedef struct OrtMIGraphXProviderOptions { + int device_id; // hip device id. + int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true + int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true + int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true + const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name + int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, noznero = true + const char* migraphx_save_model_path; // migraphx model path name + int migraphx_load_compiled_model; // migraphx int8 cal table. 
Default 0 = false, noznero = true + const char* migraphx_load_model_path; // migraphx model path name + bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false +} OrtMIGraphXProviderOptions; + /** \brief OpenVINO Provider Options -* -* \see OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO -*/ + * + * \see OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO + */ typedef struct OrtOpenVINOProviderOptions { #ifdef __cplusplus - OrtOpenVINOProviderOptions() : device_type{}, enable_vpu_fast_compile{}, device_id{}, num_of_threads{}, use_compiled_network{}, blob_dump_path{}, context{} {} + OrtOpenVINOProviderOptions() : device_type{}, + enable_npu_fast_compile{}, + device_id{}, + num_of_threads{}, + cache_dir{}, + context{}, + enable_opencl_throttling{}, + enable_dynamic_shapes{} {} #endif /** \brief Device type string - * - * Valid settings are one of: "CPU_FP32", "GPU_FP32", "GPU_FP16", "MYRIAD_FP16", "VAD-M_FP16" or "VAD-F_FP32" - */ + * + * Valid settings are one of: "CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16" + */ const char* device_type; - unsigned char enable_vpu_fast_compile; ///< 0 = disabled, nonzero = enabled + unsigned char enable_npu_fast_compile; const char* device_id; - size_t num_of_threads; ///< 0 = Use default number of threads - unsigned char use_compiled_network; ///< 0 = disabled, nonzero = enabled - const char* blob_dump_path; // path is set to empty by default + size_t num_of_threads; ///< 0 = Use default number of threads + const char* cache_dir; // path is set to empty by default void* context; + unsigned char enable_opencl_throttling; ///< 0 = disabled, nonzero = enabled + unsigned char enable_dynamic_shapes; ///< 0 = disabled, nonzero = enabled } OrtOpenVINOProviderOptions; struct OrtApi; typedef struct OrtApi OrtApi; +struct OrtTrainingApi; +typedef struct OrtTrainingApi OrtTrainingApi; + /** \brief The helper interface to get the right version of OrtApi -* -* Get a pointer to this structure through ::OrtGetApiBase 
-*/ + * + * Get a pointer to this structure through ::OrtGetApiBase + */ struct OrtApiBase { /** \brief Get a pointer to the requested version of the ::OrtApi - * - * \param[in] version Must be ::ORT_API_VERSION - * \return The ::OrtApi for the version requested, nullptr will be returned if this version is unsupported, for example when using a runtime - * older than the version created with this header file. - */ + * + * \param[in] version Must be ::ORT_API_VERSION + * \return The ::OrtApi for the version requested, nullptr will be returned if this version is unsupported, for example when using a runtime + * older than the version created with this header file. + * + * One can call GetVersionString() to get the version of the Onnxruntime library for logging + * and error reporting purposes. + */ const OrtApi*(ORT_API_CALL* GetApi)(uint32_t version)NO_EXCEPTION; - const char*(ORT_API_CALL* GetVersionString)(void)NO_EXCEPTION; ///< Returns a null terminated string of the version of the Onnxruntime library (eg: "1.8.1") + + /** \brief Returns a null terminated string of the version of the Onnxruntime library (eg: "1.8.1") + * + * \return UTF-8 encoded version string. Do not deallocate the returned buffer. 
+ */ + const char*(ORT_API_CALL* GetVersionString)(void)NO_EXCEPTION; }; + typedef struct OrtApiBase OrtApiBase; /** \brief The Onnxruntime library's entry point to access the C API -* -* Call this to get the a pointer to an ::OrtApiBase -*/ + * + * Call this to get the a pointer to an ::OrtApiBase + */ ORT_EXPORT const OrtApiBase* ORT_API_CALL OrtGetApiBase(void) NO_EXCEPTION; /** \brief Thread work loop function -* -* Onnxruntime will provide the working loop on custom thread creation -* Argument is an onnxruntime built-in type which will be provided when thread pool calls OrtCustomCreateThreadFn -*/ + * + * Onnxruntime will provide the working loop on custom thread creation + * Argument is an onnxruntime built-in type which will be provided when thread pool calls OrtCustomCreateThreadFn + */ typedef void (*OrtThreadWorkerFn)(void* ort_worker_fn_param); -typedef const struct OrtCustomHandleType{ char __place_holder; }* OrtCustomThreadHandle; +typedef const struct OrtCustomHandleType { + char __place_holder; +}* OrtCustomThreadHandle; /** \brief Ort custom thread creation function -* -* The function should return a thread handle to be used in onnxruntime thread pools -* Onnxruntime will throw exception on return value of nullptr or 0, indicating that the function failed to create a thread -*/ + * + * The function should return a thread handle to be used in onnxruntime thread pools + * Onnxruntime will throw exception on return value of nullptr or 0, indicating that the function failed to create a thread + */ typedef OrtCustomThreadHandle (*OrtCustomCreateThreadFn)(void* ort_custom_thread_creation_options, OrtThreadWorkerFn ort_thread_worker_fn, void* ort_worker_fn_param); /** \brief Custom thread join function -* -* Onnxruntime thread pool destructor will call the function to join a custom thread. 
-* Argument ort_custom_thread_handle is the value returned by OrtCustomCreateThreadFn -*/ + * + * Onnxruntime thread pool destructor will call the function to join a custom thread. + * Argument ort_custom_thread_handle is the value returned by OrtCustomCreateThreadFn + */ typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_handle); +typedef OrtStatus*(ORT_API_CALL* RegisterCustomOpsFn)(OrtSessionOptions* options, const OrtApiBase* api); + +/** \brief Callback function for RunAsync + * + * \param[in] user_data User specific data that passed back to the callback + * \param[out] outputs On succeed, outputs host inference results, on error, the value will be nullptr + * \param[out] num_outputs Number of outputs, on error, the value will be zero + * \param[out] status On error, status will provide details + */ +typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status); + /** \brief The C API -* -* All C API functions are defined inside this structure as pointers to functions. -* Call OrtApiBase::GetApi to get a pointer to it -* -* \nosubgrouping -*/ + * + * All C API functions are defined inside this structure as pointers to functions. + * Call OrtApiBase::GetApi to get a pointer to it + * + * \nosubgrouping + */ struct OrtApi { /// \name OrtStatus /// @{ /** - * \brief Create an OrtStatus from a null terminated string - * - * \param[in] code - * \param[in] msg A null-terminated string. Its contents will be copied. - * \return A new OrtStatus object, must be destroyed with OrtApi::ReleaseStatus - */ + * \brief Create an OrtStatus from a null terminated string + * + * \param[in] code + * \param[in] msg A null-terminated string. Its contents will be copied. 
+ * \return A new OrtStatus object, must be destroyed with OrtApi::ReleaseStatus + */ OrtStatus*(ORT_API_CALL* CreateStatus)(OrtErrorCode code, _In_ const char* msg)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; /** \brief Get OrtErrorCode from OrtStatus - * - * \param[in] status - * \return OrtErrorCode that \p status was created with - */ + * + * \param[in] status + * \return OrtErrorCode that \p status was created with + */ OrtErrorCode(ORT_API_CALL* GetErrorCode)(_In_ const OrtStatus* status) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; /** \brief Get error string from OrtStatus - * - * \param[in] status - * \return The error message inside the `status`. Do not free the returned value. - */ + * + * \param[in] status + * \return The error message inside the `status`. Do not free the returned value. + */ const char*(ORT_API_CALL* GetErrorMessage)(_In_ const OrtStatus* status)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; /// @} @@ -590,44 +766,49 @@ struct OrtApi { /// @{ /** \brief Create an OrtEnv - * - * \param[in] log_severity_level The log severity level. - * \param[in] logid The log identifier. - * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \note Invoking this function will return the same instance of the environment as that returned by a previous call + * to another env creation function; all arguments to this function will be ignored. + * \param[in] log_severity_level The log severity level. + * \param[in] logid The log identifier. + * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateEnv, OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); /** \brief Create an OrtEnv - * - * \param[in] logging_function A pointer to a logging function. 
- * \param[in] logger_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to - * `logging_function`. - * \param[in] log_severity_level The log severity level. - * \param[in] logid The log identifier. - * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateEnvWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, - OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); + * + * \note Invoking this function will return the same instance of the environment as that returned by a previous call + * to another env creation function; all arguments to this function will be ignored. If you want to provide your + * own logging function, consider setting it using the SetUserLoggingFunction API instead. + * \param[in] logging_function A pointer to a logging function. + * \param[in] logger_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to + * `logging_function`. This parameter is optional. + * \param[in] log_severity_level The log severity level. + * \param[in] logid The log identifier. + * \param[out] out Returned newly created OrtEnv. 
Must be freed with OrtApi::ReleaseEnv + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateEnvWithCustomLogger, _In_ OrtLoggingFunction logging_function, _In_opt_ void* logger_param, + _In_ OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); /** \brief Enable Telemetry - * - * \note Telemetry events are on by default since they are lightweight - * \param[in] env - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \note Telemetry events are on by default since they are lightweight + * \param[in] env + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(EnableTelemetryEvents, _In_ const OrtEnv* env); /** \brief Disable Telemetry - * - * \see OrtApi::EnableTelemetryEvents - * \param[in] env - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \see OrtApi::EnableTelemetryEvents + * \param[in] env + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(DisableTelemetryEvents, _In_ const OrtEnv* env); /// @} @@ -635,14 +816,14 @@ struct OrtApi { /// @{ /** \brief Create an OrtSession from a model file - * - * \param[in] env - * \param[in] model_path - * \param[in] options - * \param[out] out Returned newly created OrtSession. Must be freed with OrtApi::ReleaseSession - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] env + * \param[in] model_path + * \param[in] options + * \param[out] out Returned newly created OrtSession. Must be freed with OrtApi::ReleaseSession + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ // TODO: document the path separator convention? '/' vs '\' // TODO: should specify the access characteristics of model_path. 
Is this read only during the // execution of CreateSession, or does the OrtSession retain a handle to the file/directory @@ -652,36 +833,36 @@ struct OrtApi { _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); /** \brief Create an OrtSession from memory - * - * \param[in] env - * \param[in] model_data - * \param[in] model_data_length - * \param[in] options - * \param[out] out Returned newly created OrtSession. Must be freed with OrtApi::ReleaseSession - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] env + * \param[in] model_data + * \param[in] model_data_length + * \param[in] options + * \param[out] out Returned newly created OrtSession. Must be freed with OrtApi::ReleaseSession + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); /** \brief Run the model in an ::OrtSession - * - * Will not return until the model run has completed. Multiple threads might be used to run the model based on - * the options in the ::OrtSession and settings used when creating the ::OrtEnv - * - * \param[in] session - * \param[in] run_options If nullptr, will use a default ::OrtRunOptions - * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names - * \param[in] inputs Array of ::OrtValue%s of the input values - * \param[in] input_len Number of elements in the input_names and inputs arrays - * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names - * \param[in] output_names_len Number of elements in the output_names and outputs array - * \param[out] outputs Array of ::OrtValue%s that the outputs are stored in. This can also be - * an array of nullptr values, in this case ::OrtValue objects will be allocated and pointers - * to them will be set into the `outputs` array. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Will not return until the model run has completed. Multiple threads might be used to run the model based on + * the options in the ::OrtSession and settings used when creating the ::OrtEnv + * + * \param[in] session + * \param[in] run_options If nullptr, will use a default ::OrtRunOptions + * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names + * \param[in] inputs Array of ::OrtValue%s of the input values + * \param[in] input_len Number of elements in the input_names and inputs arrays + * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names + * \param[in] output_names_len Number of elements in the output_names and outputs array + * \param[out] outputs Array of ::OrtValue%s that the outputs are stored in. This can also be + * an array of nullptr values, in this case ::OrtValue objects will be allocated and pointers + * to them will be set into the `outputs` array. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(Run, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options, _In_reads_(input_len) const char* const* input_names, _In_reads_(input_len) const OrtValue* const* inputs, size_t input_len, @@ -693,182 +874,182 @@ struct OrtApi { /// @{ /** \brief Create an ::OrtSessionOptions object - * - * To use additional providers, you must build ORT with the extra providers enabled. Then call one of these - * functions to enable them in the session:
- * OrtSessionOptionsAppendExecutionProvider_CPU
- * OrtSessionOptionsAppendExecutionProvider_CUDA
- * OrtSessionOptionsAppendExecutionProvider_(remaining providers...)
- * The order they are called indicates the preference order as well. In other words call this method - * on your most preferred execution provider first followed by the less preferred ones. - * If none are called Ort will use its internal CPU execution provider. - * - * \param[out] options The newly created OrtSessionOptions. Must be freed with OrtApi::ReleaseSessionOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * To use additional providers, you must build ORT with the extra providers enabled. Then call one of these + * functions to enable them in the session:
+ * OrtSessionOptionsAppendExecutionProvider_CPU<br>
+ * OrtSessionOptionsAppendExecutionProvider_CUDA<br>
+ * OrtSessionOptionsAppendExecutionProvider_(remaining providers...)<br>
+ * The order they are called indicates the preference order as well. In other words call this method + * on your most preferred execution provider first followed by the less preferred ones. + * If none are called Ort will use its internal CPU execution provider. + * + * \param[out] options The newly created OrtSessionOptions. Must be freed with OrtApi::ReleaseSessionOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateSessionOptions, _Outptr_ OrtSessionOptions** options); /** \brief Set filepath to save optimized model after graph level transformations - * - * \param[in] options - * \param[in] optimized_model_filepath - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] options + * \param[in] optimized_model_filepath + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetOptimizedModelFilePath, _Inout_ OrtSessionOptions* options, _In_ const ORTCHAR_T* optimized_model_filepath); /** \brief Create a copy of an existing ::OrtSessionOptions - * - * \param[in] in_options OrtSessionOptions to copy - * \param[out] out_options Returned newly created ::OrtSessionOptions. Must be freed with OrtApi::ReleaseSessionOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] in_options OrtSessionOptions to copy + * \param[out] out_options Returned newly created ::OrtSessionOptions. Must be freed with OrtApi::ReleaseSessionOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CloneSessionOptions, _In_ const OrtSessionOptions* in_options, _Outptr_ OrtSessionOptions** out_options); /** \brief Set execution mode - * - * Controls whether you want to execute operators in your graph sequentially or in parallel. Usually when the model - * has many branches, setting this option to ExecutionMode.ORT_PARALLEL will give you better performance. - * See [docs/ONNX_Runtime_Perf_Tuning.md] for more details. 
- * - * \param[in] options - * \param[in] execution_mode - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Controls whether you want to execute operators in your graph sequentially or in parallel. Usually when the model + * has many branches, setting this option to ExecutionMode.ORT_PARALLEL will give you better performance. + * See [docs/ONNX_Runtime_Perf_Tuning.md] for more details. + * + * \param[in] options + * \param[in] execution_mode + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetSessionExecutionMode, _Inout_ OrtSessionOptions* options, ExecutionMode execution_mode); /** \brief Enable profiling for a session - * - * \param[in] options - * \param[in] profile_file_prefix - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] options + * \param[in] profile_file_prefix + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(EnableProfiling, _Inout_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix); /** \brief Disable profiling for a session - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(DisableProfiling, _Inout_ OrtSessionOptions* options); /** \brief Enable the memory pattern optimization - * - * The idea is if the input shapes are the same, we could trace the internal memory allocation - * and generate a memory pattern for future request. So next time we could just do one allocation - * with a big chunk for all the internal memory allocation. 
- * \note Memory pattern optimization is only available when Sequential Execution mode is enabled (see OrtApi::SetSessionExecutionMode) - * - * \see OrtApi::DisableMemPattern - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * The idea is if the input shapes are the same, we could trace the internal memory allocation + * and generate a memory pattern for future request. So next time we could just do one allocation + * with a big chunk for all the internal memory allocation. + * \note Memory pattern optimization is only available when Sequential Execution mode is enabled (see OrtApi::SetSessionExecutionMode) + * + * \see OrtApi::DisableMemPattern + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(EnableMemPattern, _Inout_ OrtSessionOptions* options); /** \brief Disable the memory pattern optimization - * - * \see OrtApi::EnableMemPattern - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \see OrtApi::EnableMemPattern + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(DisableMemPattern, _Inout_ OrtSessionOptions* options); /** \brief Enable the memory arena on CPU - * - * Arena may pre-allocate memory for future usage. - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Arena may pre-allocate memory for future usage. 
+ * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(EnableCpuMemArena, _Inout_ OrtSessionOptions* options); /** \brief Disable the memory arena on CPU - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(DisableCpuMemArena, _Inout_ OrtSessionOptions* options); /** \brief Set session log id - * - * \param[in] options - * \param[in] logid The log identifier. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] options + * \param[in] logid The log identifier. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetSessionLogId, _Inout_ OrtSessionOptions* options, const char* logid); /** \brief Set session log verbosity level - * - * Applies to session load, initialization, etc - * - * \param[in] options - * \param[in] session_log_verbosity_level \snippet{doc} snippets.dox Log Verbosity Level - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Applies to session load, initialization, etc + * + * \param[in] options + * \param[in] session_log_verbosity_level \snippet{doc} snippets.dox Log Verbosity Level + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetSessionLogVerbosityLevel, _Inout_ OrtSessionOptions* options, int session_log_verbosity_level); /** \brief Set session log severity level - * - * \param[in] options - * \param[in] session_log_severity_level The log severity level (refer to ::OrtLoggingLevel for possible values). - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] options + * \param[in] session_log_severity_level The log severity level (refer to ::OrtLoggingLevel for possible values). 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetSessionLogSeverityLevel, _Inout_ OrtSessionOptions* options, int session_log_severity_level); /** \brief Set the optimization level to apply when loading a graph - * - * Please see https://www.onnxruntime.ai/docs/resources/graph-optimizations.html for an in-depth explanation - * \param[in,out] options The session options object - * \param[in] graph_optimization_level The optimization level - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Please see https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html for an in-depth explanation + * \param[in,out] options The session options object + * \param[in] graph_optimization_level The optimization level + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetSessionGraphOptimizationLevel, _Inout_ OrtSessionOptions* options, GraphOptimizationLevel graph_optimization_level); /** \brief Sets the number of threads used to parallelize the execution within nodes - * - * When running a single node operation, ex. add, this sets the maximum number of threads to use. - * - * \note If built with OpenMP, this has no effect on the number of threads used. In this case - * use the OpenMP env variables to configure the number of intra op num threads. - * - * \param[in] options - * \param[in] intra_op_num_threads Number of threads to use
- * A value of 0 will use the default number of threads
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * When running a single node operation, ex. add, this sets the maximum number of threads to use. + * + * \note If built with OpenMP, this has no effect on the number of threads used. In this case + * use the OpenMP env variables to configure the number of intra op num threads. + * + * \param[in] options + * \param[in] intra_op_num_threads Number of threads to use
+ * A value of 0 will use the default number of threads
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetIntraOpNumThreads, _Inout_ OrtSessionOptions* options, int intra_op_num_threads); /** \brief Sets the number of threads used to parallelize the execution of the graph - * - * If nodes can be run in parallel, this sets the maximum number of threads to use to run them in parallel. - * - * \note If sequential execution is enabled this value is ignored, it acts as if it was set to 1. - * - * \param[in] options - * \param[in] inter_op_num_threads Number of threads to use
- * A value of 0 will use the default number of threads
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * If nodes can be run in parallel, this sets the maximum number of threads to use to run them in parallel. + * + * \note If sequential execution is enabled this value is ignored, it acts as if it was set to 1. + * + * \param[in] options + * \param[in] inter_op_num_threads Number of threads to use
+ * A value of 0 will use the default number of threads
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetInterOpNumThreads, _Inout_ OrtSessionOptions* options, int inter_op_num_threads); /// @} @@ -876,23 +1057,23 @@ struct OrtApi { /// @{ /** \brief Create a custom op domain - * - * \param[in] domain - * \param[out] out Newly created domain. Must be freed with OrtApi::ReleaseCustomOpDomain - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] domain + * \param[out] out Newly created domain. Must be freed with OrtApi::ReleaseCustomOpDomain + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateCustomOpDomain, _In_ const char* domain, _Outptr_ OrtCustomOpDomain** out); /** \brief Add a custom op to a custom op domain - * - * \note The OrtCustomOp* pointer must remain valid until the ::OrtCustomOpDomain using it is released - * - * \param[in] custom_op_domain - * \param[in] op - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \note The OrtCustomOp* pointer must remain valid until the ::OrtCustomOpDomain using it is released + * + * \param[in] custom_op_domain + * \param[in] op + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CustomOpDomain_Add, _Inout_ OrtCustomOpDomain* custom_op_domain, _In_ const OrtCustomOp* op); /// @} @@ -900,134 +1081,136 @@ struct OrtApi { /// @{ /** \brief Add custom op domain to a session options - * - * \note The OrtCustomOpDomain* must not be deleted until all sessions using it are released - * - * \param[in] options - * \param[in] custom_op_domain - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \note The OrtCustomOpDomain* must not be deleted until all sessions using it are released + * + * \param[in] options + * \param[in] custom_op_domain + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(AddCustomOpDomain, _Inout_ OrtSessionOptions* options, _In_ OrtCustomOpDomain* custom_op_domain); - /** \brief 
Register custom ops from a shared library - * - * Loads a shared library (dll on windows, so on linux, etc) named 'library_path' and looks for this entry point: - * OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api); - * It then passes in the provided session options to this function along with the api base. - * The handle to the loaded library is returned in library_handle. It can be freed by the caller after all sessions using the passed in - * session options are destroyed, or if an error occurs and it is non null. - * - * \param[in] options - * \param[in] library_path - * \param[out] library_handle OS specific handle to the loaded library (Use FreeLibrary on Windows, dlclose on Linux, etc.. to unload) - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions* options, _In_ const char* library_path, void** library_handle); + /** \deprecated Use OrtApi::RegisterCustomOpsLibrary_V2. + * + * Registers custom ops from a shared library. + * + * Loads a shared library (dll on windows, so on linux, etc) named 'library_path' and looks for this entry point: + * OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api); + * It then passes in the provided session options to this function along with the api base. + * The handle to the loaded library is returned in library_handle. It can be freed by the caller after all sessions using the passed in + * session options are destroyed, or if an error occurs and it is non null. + * + * \param[in] options + * \param[in] library_path + * \param[out] library_handle OS specific handle to the loaded library (Use FreeLibrary on Windows, dlclose on Linux, etc.. 
to unload) + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions* options, _In_ const char* library_path, _Outptr_ void** library_handle); /// @} /// \name OrtSession /// @{ /** \brief Get input count for a session - * - * This number must also match the number of inputs passed to OrtApi::Run - * - * \see OrtApi::SessionGetInputTypeInfo, OrtApi::SessionGetInputName, OrtApi::Session - * - * \param[in] session - * \param[out] out Number of inputs - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * This number must also match the number of inputs passed to OrtApi::Run + * + * \see OrtApi::SessionGetInputTypeInfo, OrtApi::SessionGetInputName, OrtApi::Session + * + * \param[in] session + * \param[out] out Number of inputs + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionGetInputCount, _In_ const OrtSession* session, _Out_ size_t* out); /** \brief Get output count for a session - * - * This number must also match the number of outputs returned by OrtApi::Run - * - * \see OrtApi::SessionGetOutputTypeInfo, OrtApi::SessionGetOutputName, OrtApi::Session - * - * \param[in] session - * \param[out] out Number of outputs - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetOutputCount, _In_ const OrtSession* session, _Out_ size_t* out); + * + * This number must also match the number of outputs returned by OrtApi::Run + * + * \see OrtApi::SessionGetOutputTypeInfo, OrtApi::SessionGetOutputName, OrtApi::Session + * + * \param[in] session + * \param[out] out Number of outputs + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetOutputCount, _In_ const OrtSession* session, _Out_ size_t* out); /** \brief Get overridable initializer count - * - * \see OrtApi::SessionGetOverridableInitializerTypeInfo, OrtApi::SessionGetOverridableInitializerName - * - * \param[in] session - * 
\param[in] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \see OrtApi::SessionGetOverridableInitializerTypeInfo, OrtApi::SessionGetOverridableInitializerName + * + * \param[in] session + * \param[in] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionGetOverridableInitializerCount, _In_ const OrtSession* session, _Out_ size_t* out); /** \brief Get input type information - * - * \param[in] session - * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetInputCount returns (exclusive) - * \param[out] type_info Must be freed with OrtApi::ReleaseTypeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] session + * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetInputCount returns (exclusive) + * \param[out] type_info Must be freed with OrtApi::ReleaseTypeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionGetInputTypeInfo, _In_ const OrtSession* session, size_t index, _Outptr_ OrtTypeInfo** type_info); /** \brief Get output type information - * - * \param[in] session - * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOutputCount returns (exclusive) - * \param[out] type_info Must be freed with OrtApi::ReleaseTypeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] session + * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOutputCount returns (exclusive) + * \param[out] type_info Must be freed with OrtApi::ReleaseTypeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionGetOutputTypeInfo, _In_ const OrtSession* session, size_t index, _Outptr_ OrtTypeInfo** type_info); /** \brief Get overridable initializer type information - * - * \param[in] session - * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOverridableInitializerCount returns 
(exclusive) - * \param[out] type_info Must be freed with OrtApi::ReleaseTypeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] session + * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOverridableInitializerCount returns (exclusive) + * \param[out] type_info Must be freed with OrtApi::ReleaseTypeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionGetOverridableInitializerTypeInfo, _In_ const OrtSession* session, size_t index, _Outptr_ OrtTypeInfo** type_info); /** \brief Get input name - * - * \param[in] session - * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetInputCount returns (exclusive) - * \param[in] allocator - * \param[out] value Set to a null terminated UTF-8 encoded string allocated using `allocator`. Must be freed using `allocator`. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] session + * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetInputCount returns (exclusive) + * \param[in] allocator + * \param[out] value Set to a null terminated UTF-8 encoded string allocated using `allocator`. Must be freed using `allocator`. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionGetInputName, _In_ const OrtSession* session, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); /** \brief Get output name - * - * \param[in] session - * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOutputCount returns (exclusive) - * \param[in] allocator - * \param[out] value Set to a null terminated UTF-8 encoded string allocated using `allocator`. Must be freed using `allocator`. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] session + * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOutputCount returns (exclusive) + * \param[in] allocator + * \param[out] value Set to a null terminated UTF-8 encoded string allocated using `allocator`. Must be freed using `allocator`. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionGetOutputName, _In_ const OrtSession* session, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); /** \brief Get overridable initializer name - * - * \param[in] session - * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOverridableInitializerCount returns (exclusive) - * \param[in] allocator - * \param[out] value Set to a null terminated UTF-8 encoded string allocated using `allocator`. Must be freed using `allocator`. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] session + * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOverridableInitializerCount returns (exclusive) + * \param[in] allocator + * \param[out] value Set to a null terminated UTF-8 encoded string allocated using `allocator`. Must be freed using `allocator`. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionGetOverridableInitializerName, _In_ const OrtSession* session, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); @@ -1036,11 +1219,11 @@ struct OrtApi { /// @{ /** \brief Create an OrtRunOptions - * - * \param[out] out Returned newly created ::OrtRunOptions. Must be freed with OrtApi::ReleaseRunOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[out] out Returned newly created ::OrtRunOptions. 
Must be freed with OrtApi::ReleaseRunOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateRunOptions, _Outptr_ OrtRunOptions** out); /** \brief Set per-run log verbosity level @@ -1109,23 +1292,23 @@ struct OrtApi { ORT_API2_STATUS(RunOptionsGetRunTag, _In_ const OrtRunOptions* options, _Out_ const char** run_tag); /** \brief Set terminate flag - * - * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error. - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error. + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(RunOptionsSetTerminate, _Inout_ OrtRunOptions* options); /** \brief Clears the terminate flag - * - * Used so the OrtRunOptions instance can be used in a new OrtApi::Run call without it instantly terminating - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Used so the OrtRunOptions instance can be used in a new OrtApi::Run call without it instantly terminating + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(RunOptionsUnsetTerminate, _Inout_ OrtRunOptions* options); /// @} @@ -1133,17 +1316,17 @@ struct OrtApi { /// @{ /** \brief Create a tensor - * - * Create a tensor using a supplied ::OrtAllocator - * - * \param[in] allocator - * \param[in] shape Tensor shape - * \param[in] shape_len Number of elements in `shape` - * \param[in] type - * \param[out] out Returns newly created ::OrtValue. 
Must be freed with OrtApi::ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Create a tensor using a supplied ::OrtAllocator + * + * \param[in] allocator + * \param[in] shape Pointer to the tensor shape dimensions. + * \param[in] shape_len The number of tensor shape dimensions. + * \param[in] type + * \param[out] out Returns newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out); @@ -1152,81 +1335,81 @@ struct OrtApi { * Create a tensor with user's buffer. You can fill the buffer either before calling this function or after. * p_data is owned by caller. ReleaseValue won't release p_data. * - * \param[in] info - * \param[in] p_data - * \param[in] p_data_len - * \param[in] shape - * \param[in] shape_len - * \param[in] type + * \param[in] info Memory description of where the p_data buffer resides (CPU vs GPU etc). + * \param[in] p_data Pointer to the data buffer. + * \param[in] p_data_len The number of bytes in the data buffer. + * \param[in] shape Pointer to the tensor shape dimensions. + * \param[in] shape_len The number of tensor shape dimensions. + * \param[in] type The data type. * \param[out] out Returns newly created ::OrtValue. 
Must be freed with OrtApi::ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_API2_STATUS(CreateTensorWithDataAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data, size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out); /** \brief Return if an ::OrtValue is a tensor type - * - * \param[in] value A tensor type (string tensors are not supported) - * \param[out] out Set to 1 iff ::OrtValue is a tensor, 0 otherwise - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] value A tensor type (string tensors are not supported) + * \param[out] out Set to 1 iff ::OrtValue is a tensor, 0 otherwise + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(IsTensor, _In_ const OrtValue* value, _Out_ int* out); /** \brief Get a pointer to the raw data inside a tensor - * - * Used to read/write/modify the internal tensor data directly. - * \note The returned pointer is valid until the \p value is destroyed. - * - * \param[in] value A tensor type (string tensors are not supported) - * \param[out] out Filled in with a pointer to the internal storage - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Used to read/write/modify the internal tensor data directly. + * \note The returned pointer is valid until the \p value is destroyed. + * + * \param[in] value A tensor type (string tensors are not supported) + * \param[out] out Filled in with a pointer to the internal storage + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetTensorMutableData, _In_ OrtValue* value, _Outptr_ void** out); /** \brief Set all strings at once in a string tensor - * - * \param[in,out] value A tensor of type ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING - * \param[in] s An array of strings. Each string in this array must be null terminated. 
- * \param[in] s_len Count of strings in s (Must match the size of \p value's tensor shape) - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in,out] value A tensor of type ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING + * \param[in] s An array of strings. Each string in this array must be null terminated. + * \param[in] s_len Count of strings in s (Must match the size of \p value's tensor shape) + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(FillStringTensor, _Inout_ OrtValue* value, _In_ const char* const* s, size_t s_len); /** \brief Get total byte length for all strings in a string tensor - * - * Typically used with OrtApi::GetStringTensorContent - * - * \param[in] value A tensor of type ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING - * \param[out] len Total byte length of all strings (does not include trailing nulls) - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Typically used with OrtApi::GetStringTensorContent + * + * \param[in] value A tensor of type ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING + * \param[out] len Total byte length of all strings (does not include trailing nulls) + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetStringTensorDataLength, _In_ const OrtValue* value, _Out_ size_t* len); /** \brief Get all strings from a string tensor - * - * An example of the results:
- * Given \p value is a string tensor with the strings { "This" "is" "a" "test" }
- * \p s must have a size of 11 bytes
- * \p offsets must have 4 elements
- * After the call, these values will be filled in:
- * \p s will contain "Thisisatest"
- * \p offsets will contain { 0, 4, 6, 7 }
- * The length of the last string is just s_len - offsets[last] - * - * \param[in] value A tensor of type ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING - * \param[in] s Buffer to sequentially write all tensor strings to. Each string is NOT null-terminated. - * \param[in] s_len Number of bytes of buffer pointed to by \p s (Get it from OrtApi::GetStringTensorDataLength) - * \param[out] offsets Array of start offsets into the strings written to \p s - * \param[in] offsets_len Number of elements in offsets - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * An example of the results:
+ * Given \p value is a string tensor with the strings { "This" "is" "a" "test" }
+ * \p s must have a size of 11 bytes
+ * \p offsets must have 4 elements
+ * After the call, these values will be filled in:
+ * \p s will contain "Thisisatest"
+ * \p offsets will contain { 0, 4, 6, 7 }
+ * The length of the last string is just s_len - offsets[last] + * + * \param[in] value A tensor of type ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING + * \param[in] s Buffer to sequentially write all tensor strings to. Each string is NOT null-terminated. + * \param[in] s_len Number of bytes of buffer pointed to by \p s (Get it from OrtApi::GetStringTensorDataLength) + * \param[out] offsets Array of start offsets into the strings written to \p s + * \param[in] offsets_len Number of elements in offsets + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetStringTensorContent, _In_ const OrtValue* value, _Out_writes_bytes_all_(s_len) void* s, size_t s_len, _Out_writes_all_(offsets_len) size_t* offsets, size_t offsets_len); @@ -1235,22 +1418,23 @@ struct OrtApi { /// @{ /** \brief Get ::OrtTensorTypeAndShapeInfo from an ::OrtTypeInfo - * - * \param[in] type_info - * \param[out] out Do not free this value, it will be valid until type_info is freed. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] type_info + * \param[out] out Do not free this value, it will be valid until type_info is freed. + * If type_info does not represent tensor, this value will be set to nullptr. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CastTypeInfoToTensorInfo, _In_ const OrtTypeInfo* type_info, _Outptr_result_maybenull_ const OrtTensorTypeAndShapeInfo** out); /** \brief Get ::ONNXType from ::OrtTypeInfo - * - * \param[in] type_info - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] type_info + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetOnnxTypeFromTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ enum ONNXType* out); /// @} @@ -1258,93 +1442,93 @@ struct OrtApi { /// @{ /** \brief Create an ::OrtTensorTypeAndShapeInfo object - * - * \param[out] out Returns newly created ::OrtTensorTypeAndShapeInfo. 
Must be freed with OrtApi::ReleaseTensorTypeAndShapeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[out] out Returns newly created ::OrtTensorTypeAndShapeInfo. Must be freed with OrtApi::ReleaseTensorTypeAndShapeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateTensorTypeAndShapeInfo, _Outptr_ OrtTensorTypeAndShapeInfo** out); /** \brief Set element type in ::OrtTensorTypeAndShapeInfo - * - * \param[in] info - * \param[in] type - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] info + * \param[in] type + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetTensorElementType, _Inout_ OrtTensorTypeAndShapeInfo* info, enum ONNXTensorElementDataType type); /** \brief Set shape information in ::OrtTensorTypeAndShapeInfo - * - * \param[in] info - * \param[in] dim_values Array with `dim_count` elements. Can contain negative values. - * \param[in] dim_count Number of elements in `dim_values` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] info + * \param[in] dim_values Array with `dim_count` elements. Can contain negative values. 
+ * \param[in] dim_count Number of elements in `dim_values` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetDimensions, OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count); /** \brief Get element type in ::OrtTensorTypeAndShapeInfo - * - * \see OrtApi::SetTensorElementType - * - * \param[in] info - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \see OrtApi::SetTensorElementType + * + * \param[in] info + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetTensorElementType, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ enum ONNXTensorElementDataType* out); /** \brief Get dimension count in ::OrtTensorTypeAndShapeInfo - * - * \see OrtApi::GetDimensions - * - * \param[in] info - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \see OrtApi::GetDimensions + * + * \param[in] info + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetDimensionsCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); /** \brief Get dimensions in ::OrtTensorTypeAndShapeInfo - * - * \param[in] info - * \param[out] dim_values Array with `dim_values_length` elements. On return, filled with the dimensions stored in the ::OrtTensorTypeAndShapeInfo - * \param[in] dim_values_length Number of elements in `dim_values`. Use OrtApi::GetDimensionsCount to get this value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] info + * \param[out] dim_values Array with `dim_values_length` elements. On return, filled with the dimensions stored in the ::OrtTensorTypeAndShapeInfo + * \param[in] dim_values_length Number of elements in `dim_values`. 
Use OrtApi::GetDimensionsCount to get this value + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length); /** \brief Get symbolic dimension names in ::OrtTensorTypeAndShapeInfo - * - * \param[in] info - * \param[in] dim_params Array with `dim_params_length` elements. On return filled with pointers to null terminated strings of the dimension names - * \param[in] dim_params_length Number of elements in `dim_params`. Use OrtApi::GetDimensionsCount to get this value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] info + * \param[in] dim_params Array with `dim_params_length` elements. On return filled with pointers to null terminated strings of the dimension names + * \param[in] dim_params_length Number of elements in `dim_params`. Use OrtApi::GetDimensionsCount to get this value + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetSymbolicDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_writes_all_(dim_params_length) const char* dim_params[], size_t dim_params_length); /** \brief Get total number of elements in a tensor shape from an ::OrtTensorTypeAndShapeInfo - * - * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other). - * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1. - * - * Examples:
- * [] = 1
- * [1,3,4] = 12
- * [2,0,4] = 0
- * [-1,3,4] = -1
- * - * \param[in] info - * \param[out] out Number of elements - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other). + * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1. + * + * Examples:
+ * [] = 1
+ * [1,3,4] = 12
+ * [2,0,4] = 0
+ * [-1,3,4] = -1
+ * + * \param[in] info + * \param[out] out Number of elements + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); /// @} @@ -1352,30 +1536,30 @@ struct OrtApi { /// @{ /** \brief Get type and shape information from a tensor ::OrtValue - * - * \param[in] value Must be a tensor (not a map/sequence/etc) or will return failure - * \param[out] out Newly created ::OrtTensorTypeAndShapeInfo. Must be freed with OrtApi::ReleaseTensorTypeAndShapeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] value Must be a tensor (not a map/sequence/etc) or will return failure + * \param[out] out Newly created ::OrtTensorTypeAndShapeInfo. Must be freed with OrtApi::ReleaseTensorTypeAndShapeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out); /** \brief Get type information of an OrtValue - * - * \param[in] value - * \param[out] out Newly created ::OrtTypeInfo. Must be freed with OrtApi::ReleaseTypeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] value + * \param[out] out Newly created ::OrtTypeInfo. 
Must be freed with OrtApi::ReleaseTypeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetTypeInfo, _In_ const OrtValue* value, _Outptr_result_maybenull_ OrtTypeInfo** out); /** \brief Get ONNXType of an ::OrtValue - * - * \param[in] value - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] value + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetValueType, _In_ const OrtValue* value, _Out_ enum ONNXType* out); /// @} @@ -1383,62 +1567,62 @@ struct OrtApi { /// @{ /** \brief Create an ::OrtMemoryInfo - * - * \param[in] name - * \param[in] type - * \param[in] id - * \param[in] mem_type - * \param[out] out Newly created ::OrtMemoryInfo. Must be freed with OrtAPi::ReleaseMemoryInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] name + * \param[in] type + * \param[in] id + * \param[in] mem_type + * \param[out] out Newly created ::OrtMemoryInfo. Must be freed with OrtAPi::ReleaseMemoryInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateMemoryInfo, _In_ const char* name, enum OrtAllocatorType type, int id, enum OrtMemType mem_type, _Outptr_ OrtMemoryInfo** out); /** \brief Create an ::OrtMemoryInfo for CPU memory - * - * Special case version of OrtApi::CreateMemoryInfo for CPU based memory. Same as using OrtApi::CreateMemoryInfo with name = "Cpu" and id = 0. - * - * \param[in] type - * \param[in] mem_type - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Special case version of OrtApi::CreateMemoryInfo for CPU based memory. Same as using OrtApi::CreateMemoryInfo with name = "Cpu" and id = 0. 
+ * + * \param[in] type + * \param[in] mem_type + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateCpuMemoryInfo, enum OrtAllocatorType type, enum OrtMemType mem_type, _Outptr_ OrtMemoryInfo** out); /** \brief Compare ::OrtMemoryInfo objects for equality - * - * Compares all settings of each ::OrtMemoryInfo for equality - * - * \param[in] info1 - * \param[in] info2 - * \param[out] out Set to 0 if equal, -1 if not equal - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Compares all settings of each ::OrtMemoryInfo for equality + * + * \param[in] info1 + * \param[in] info2 + * \param[out] out Set to 0 if equal, -1 if not equal + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CompareMemoryInfo, _In_ const OrtMemoryInfo* info1, _In_ const OrtMemoryInfo* info2, _Out_ int* out); /** \brief Get name from ::OrtMemoryInfo - * - * \param[in] ptr - * \param[out] out Writes null terminated string to this pointer. Do NOT free the returned pointer. It is valid for the lifetime of the ::OrtMemoryInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] ptr + * \param[out] out Writes null terminated string to this pointer. Do NOT free the returned pointer. 
It is valid for the lifetime of the ::OrtMemoryInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(MemoryInfoGetName, _In_ const OrtMemoryInfo* ptr, _Out_ const char** out); /** \brief Get the id from ::OrtMemoryInfo - */ + */ ORT_API2_STATUS(MemoryInfoGetId, _In_ const OrtMemoryInfo* ptr, _Out_ int* out); /** \brief Get the ::OrtMemType from ::OrtMemoryInfo - */ + */ ORT_API2_STATUS(MemoryInfoGetMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtMemType* out); /** \brief Get the ::OrtAllocatorType from ::OrtMemoryInfo - */ + */ ORT_API2_STATUS(MemoryInfoGetType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtAllocatorType* out); /// @} @@ -1453,13 +1637,13 @@ struct OrtApi { ORT_API2_STATUS(AllocatorGetInfo, _In_ const OrtAllocator* ort_allocator, _Outptr_ const struct OrtMemoryInfo** out); /** \brief Get the default allocator - * - * The default allocator is a CPU based, non-arena. Always returns the same pointer to the same default allocator. - * - * \param[out] out Returned value should NOT be freed - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * The default allocator is a CPU based, non-arena. Always returns the same pointer to the same default allocator. 
+ * + * \param[out] out Returned value should NOT be freed + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetAllocatorWithDefaultOptions, _Outptr_ OrtAllocator** out); /// @} @@ -1467,16 +1651,16 @@ struct OrtApi { /// @{ /** \brief Override session symbolic dimensions - * - * Override symbolic dimensions (by specific denotation strings) with actual values if known at session initialization time to enable - * optimizations that can take advantage of fixed values (such as memory planning, etc) - * - * \param[in] options - * \param[in] dim_denotation - * \param[in] dim_value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Override symbolic dimensions (by specific denotation strings) with actual values if known at session initialization time to enable + * optimizations that can take advantage of fixed values (such as memory planning, etc) + * + * \param[in] options + * \param[in] dim_denotation + * \param[in] dim_value + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options, _In_ const char* dim_denotation, _In_ int64_t dim_value); @@ -1485,196 +1669,214 @@ struct OrtApi { /// @{ /* Internal information (not seen in Doxygen) - * - * APIs to support non-tensor types - map and sequence. - * Currently only the following types are supported - * Note: the following types should be kept in sync with data_types.h - * Map types - * ========= - * std::map - * std::map - * std::map - * std::map - * std::map - * std::map - * std::map - * std::map - * - * Sequence types - * ============== - * std::vector - * std::vector - * std::vector - * std::vector - * std::vector> - * std::vector - */ + * + * APIs to support non-tensor types - map and sequence. 
+ * Currently only the following types are supported + * Note: the following types should be kept in sync with data_types.h + * Map types + * ========= + * std::map + * std::map + * std::map + * std::map + * std::map + * std::map + * std::map + * std::map + * + * Sequence types + * ============== + * std::vector + * std::vector + * std::vector + * std::vector + * std::vector> + * std::vector + */ /** \brief Get non tensor data from an ::OrtValue - * - * If `value` is of type ONNX_TYPE_MAP, you need to retrieve the keys and values - * separately. Use index=0 to retrieve keys and index=1 to retrieve values. - * If `value` is of type ONNX_TYPE_SEQUENCE, use index to retrieve the index'th element - * of the sequence. - * - * \param[in] value - * \param[in] index See above for usage based on `value` type - * \param[in] allocator Allocator used to allocate ::OrtValue - * \param[out] out Created ::OrtValue that holds the element requested. Must be freed with OrtApi::ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * If `value` is of type ONNX_TYPE_MAP, you need to retrieve the keys and values + * separately. Use index=0 to retrieve keys and index=1 to retrieve values. + * If `value` is of type ONNX_TYPE_SEQUENCE, use index to retrieve the index'th element + * of the sequence. + * + * \param[in] value + * \param[in] index See above for usage based on `value` type + * \param[in] allocator Allocator used to allocate ::OrtValue + * \param[out] out Created ::OrtValue that holds the element requested. Must be freed with OrtApi::ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetValue, _In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out); /** \brief Get non tensor value count from an ::OrtValue - * - * If `value` is of type ONNX_TYPE_MAP 2 will always be returned. 
For ONNX_TYPE_SEQUENCE - * the number of elements in the sequence will be returned - * - * \param[in] value - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * If `value` is of type ONNX_TYPE_MAP 2 will always be returned. For ONNX_TYPE_SEQUENCE + * the number of elements in the sequence will be returned + * + * \param[in] value + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetValueCount, _In_ const OrtValue* value, _Out_ size_t* out); /** \brief Create a map or sequence ::OrtValue - * - * To construct a map (ONNX_TYPE_MAP), use num_values = 2 and `in` should be an array of 2 ::OrtValue%s - * representing keys and values.
- * - * To construct a sequence (ONNX_TYPE_SEQUENCE), use num_values = N where N is the number of the elements in the - * sequence. 'in' should be an array of N ::OrtValue%s. - * - * \param[in] in See above for details - * \param[in] num_values - * \param[in] value_type Must be either ONNX_TYPE_MAP or ONNX_TYPE_SEQUENCE - * \param[out] out Newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * To construct a map (ONNX_TYPE_MAP), use num_values = 2 and `in` should be an array of 2 ::OrtValue%s + * representing keys and values.
+ * + * To construct a sequence (ONNX_TYPE_SEQUENCE), use num_values = N where N is the number of the elements in the + * sequence. 'in' should be an array of N ::OrtValue%s. + * + * \param[in] in See above for details + * \param[in] num_values + * \param[in] value_type Must be either ONNX_TYPE_MAP or ONNX_TYPE_SEQUENCE + * \param[out] out Newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateValue, _In_reads_(num_values) const OrtValue* const* in, size_t num_values, enum ONNXType value_type, _Outptr_ OrtValue** out); /** \brief Create an opaque (custom user defined type) ::OrtValue - * - * Constructs an ::OrtValue that contains a value of non-standard type created for - * experiments or while awaiting standardization. ::OrtValue in this case would contain - * an internal representation of the Opaque type. Opaque types are distinguished from - * each other by two strings 1) domain and 2) type name. The combination of the two - * must be unique, so the type representation is properly identified internally. The combination - * must be properly registered from within ORT at both compile/run time or by another API. - * - * To construct the ::OrtValue pass domain and type names, also a pointer to a data container - * the type of which must be known to both ORT and the client program. That data container may or may - * not match the internal representation of the Opaque type. The sizeof(data_container) is passed for - * verification purposes. - * - * \param[in] domain_name Null terminated string of the domain name - * \param[in] type_name Null terminated string of the type name - * \param[in] data_container User pointer Data to populate ::OrtValue - * \param[in] data_container_size Size in bytes of what `data_container` points to - * \param[out] out Newly created ::OrtValue. 
Must be freed with OrtApi::ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Constructs an ::OrtValue that contains a value of non-standard type created for + * experiments or while awaiting standardization. ::OrtValue in this case would contain + * an internal representation of the Opaque type. Opaque types are distinguished from + * each other by two strings 1) domain and 2) type name. The combination of the two + * must be unique, so the type representation is properly identified internally. The combination + * must be properly registered from within ORT at both compile/run time or by another API. + * + * To construct the ::OrtValue pass domain and type names, also a pointer to a data container + * the type of which must be known to both ORT and the client program. That data container may or may + * not match the internal representation of the Opaque type. The sizeof(data_container) is passed for + * verification purposes. + * + * \param[in] domain_name Null terminated string of the domain name + * \param[in] type_name Null terminated string of the type name + * \param[in] data_container User pointer Data to populate ::OrtValue + * \param[in] data_container_size Size in bytes of what `data_container` points to + * \param[out] out Newly created ::OrtValue. 
Must be freed with OrtApi::ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateOpaqueValue, _In_z_ const char* domain_name, _In_z_ const char* type_name, _In_ const void* data_container, size_t data_container_size, _Outptr_ OrtValue** out); /** \brief Get internal data from an opaque (custom user defined type) ::OrtValue - * - * Copies internal data from an opaque value into a user provided buffer - * - * \see OrtApi::CreateOpaqueValue - * - * \param[in] domain_name Null terminated string of the domain name - * \param[in] type_name Null terminated string of the type name - * \param[in] in The opaque ::OrtValue - * \param[out] data_container Buffer to copy data into - * \param[out] data_container_size Size in bytes of the buffer pointed to by data_container. Must match the size of the internal buffer. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Copies internal data from an opaque value into a user provided buffer + * + * \see OrtApi::CreateOpaqueValue + * + * \param[in] domain_name Null terminated string of the domain name + * \param[in] type_name Null terminated string of the type name + * \param[in] in The opaque ::OrtValue + * \param[out] data_container Buffer to copy data into + * \param[out] data_container_size Size in bytes of the buffer pointed to by data_container. Must match the size of the internal buffer. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetOpaqueValue, _In_ const char* domain_name, _In_ const char* type_name, _In_ const OrtValue* in, _Out_ void* data_container, size_t data_container_size); /// @} /// \name OrtKernelInfo + /// Custom operator APIs. 
/// @{ /** \brief Get a float stored as an attribute in the graph node - * - * \param[in] info ::OrtKernelInfo instance - * \param[in] name Null terminated string of the name of the attribute - * \param[out] out Pointer to memory where the attribute will be stored - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] info ::OrtKernelInfo instance + * \param[in] name Null terminated string of the name of the attribute + * \param[out] out Pointer to memory where the attribute will be stored + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out); /** \brief Fetch a 64-bit int stored as an attribute in the graph node - * - * \param[in] info ::OrtKernelInfo instance - * \param[in] name Null terminated string of the name of the attribute - * \param[out] out Pointer to memory where the attribute will be stored - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] info ::OrtKernelInfo instance + * \param[in] name Null terminated string of the name of the attribute + * \param[out] out Pointer to memory where the attribute will be stored + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(KernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out); /** \brief Fetch a string stored as an attribute in the graph node - * - * If `out` is nullptr, the value of `size` is set to the true size of the string - * attribute, and a success status is returned. - * - * If the `size` parameter is greater than or equal to the actual string attribute's size, - * the value of `size` is set to the true size of the string attribute, the provided memory - * is filled with the attribute's contents, and a success status is returned. 
- * - * If the `size` parameter is less than the actual string attribute's size and `out` - * is not nullptr, the value of `size` is set to the true size of the string attribute - * and a failure status is returned.) - * - * \param[in] info ::OrtKernelInfo instance - * \param[in] name Null terminated string of the name of the attribute - * \param[out] out Pointer to memory where the attribute will be stored - * \param[in,out] size See above comments for details - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * If `out` is nullptr, the value of `size` is set to the true size of the string + * attribute, and a success status is returned. + * + * If the `size` parameter is greater than or equal to the actual string attribute's size, + * the value of `size` is set to the true size of the string attribute, the provided memory + * is filled with the attribute's contents, and a success status is returned. + * + * If the `size` parameter is less than the actual string attribute's size and `out` + * is not nullptr, the value of `size` is set to the true size of the string attribute + * and a failure status is returned.) + * + * \param[in] info ::OrtKernelInfo instance + * \param[in] name Null terminated string of the name of the attribute + * \param[out] out Pointer to memory where the attribute will be stored + * \param[in,out] size See above comments for details + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(KernelInfoGetAttribute_string, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ char* out, _Inout_ size_t* size); /// @} /// \name OrtKernelContext + /// Custom operator APIs. 
/// @{ /** \brief Used for custom operators, get the input count of a kernel - * - * \see ::OrtCustomOp - */ + * + * \see ::OrtCustomOp + */ ORT_API2_STATUS(KernelContext_GetInputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out); /** \brief Used for custom operators, get the output count of a kernel - * - * \see ::OrtCustomOp - */ + * + * \see ::OrtCustomOp + */ ORT_API2_STATUS(KernelContext_GetOutputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out); /** \brief Used for custom operators, get an input of a kernel - * - * \see ::OrtCustomOp - */ + * + * The function attempts to fetch the input of the kernel. If the input is optional + * and not present, the function returns success and out is set to nullptr. + * + * \param[in] context ::OrtKernelContext instance + * \param[in] index See KernelContext_GetInputCount for boundaries check. + * \param[out] out OrtValue if the input is present otherwise is set nullptr + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, _Out_ const OrtValue** out); /** \brief Used for custom operators, get an output of a kernel - * - * \see ::OrtCustomOp - */ + * + * The function attempts to fetch the output of the kernel. If the output is optional + * and not present, the function returns success and out is set to nullptr. + * + * \param[in] context ::OrtKernelContext instance + * \param[in] index See KernelContext_GetOutputCount for boundaries check.
+ * \param[in] dim_values output dimensions + * \param[in] dim_count number of dimensions + * \param[out] out a ptr to OrtValue to output otherwise set to nullptr + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count, _Outptr_ OrtValue** out); @@ -1693,7 +1895,7 @@ struct OrtApi { /// @} /// \name OrtSession /// @{ - ORT_CLASS_RELEASE(Session); //Don't call ReleaseSession from Dllmain (because session owns a thread pool) + ORT_CLASS_RELEASE(Session); // Don't call ReleaseSession from Dllmain (because session owns a thread pool) /// @} /// \name OrtValue /// @{ @@ -1724,47 +1926,49 @@ struct OrtApi { /// @{ /** \brief Get denotation from type information - * - * Augments ::OrtTypeInfo to return denotations on the type. - * - * This is used by WinML to determine if an input/output is intended to be an Image or a Tensor. - * - * \param[in] type_info - * \param[out] denotation Pointer to the null terminated denotation string is written to this pointer. This pointer is valid until the object is destroyed or the name is changed, do not free. - * \param[out] len Length in bytes of the string returned in `denotation` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Augments ::OrtTypeInfo to return denotations on the type. + * + * This is used by WinML to determine if an input/output is intended to be an Image or a Tensor. + * + * \param[in] type_info + * \param[out] denotation Pointer to the null terminated denotation string is written to this pointer. This pointer is valid until the object is destroyed or the name is changed, do not free. 
+ * \param[out] len Length in bytes of the string returned in `denotation` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ const char** const denotation, _Out_ size_t* len); /** \brief Get detailed map information from an ::OrtTypeInfo - * - * This augments ::OrtTypeInfo to return an ::OrtMapTypeInfo when the type is a map. - * The OrtMapTypeInfo has additional information about the map's key type and value type. - * - * This is used by WinML to support model reflection APIs. - * - * \param[out] type_info - * \param[out] out A pointer to the ::OrtMapTypeInfo. Do not free this value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * This augments ::OrtTypeInfo to return an ::OrtMapTypeInfo when the type is a map. + * The OrtMapTypeInfo has additional information about the map's key type and value type. + * + * This is used by WinML to support model reflection APIs. + * + * \param[in] type_info + * \param[out] out A pointer to the ::OrtMapTypeInfo. Do not free this value. If type_info + * does not contain a map, this value will be set to nullptr. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* type_info, _Outptr_result_maybenull_ const OrtMapTypeInfo** out); /** \brief Cast ::OrtTypeInfo to an ::OrtSequenceTypeInfo - * - * This api augments ::OrtTypeInfo to return an ::OrtSequenceTypeInfo when the type is a sequence. - * The ::OrtSequenceTypeInfo has additional information about the sequence's element type. - * - * This is used by WinML to support model reflection APIs. - * - * \param[in] type_info - * \param[out] out A pointer to the OrtSequenceTypeInfo. Do not free this value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * This api augments ::OrtTypeInfo to return an ::OrtSequenceTypeInfo when the type is a sequence.
+ * The ::OrtSequenceTypeInfo has additional information about the sequence's element type. + * + * This is used by WinML to support model reflection APIs. + * + * \param[in] type_info + * \param[out] out A pointer to the OrtSequenceTypeInfo. Do not free this value. If type_info + * does not contain a sequence, this value will be set to nullptr. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeInfo* type_info, _Outptr_result_maybenull_ const OrtSequenceTypeInfo** out); @@ -1773,25 +1977,25 @@ struct OrtApi { /// @{ /** \brief Get key type from an ::OrtMapTypeInfo - * - * Key types are restricted to being scalar types. - * - * This is used by WinML to support model reflection APIs. - * - * \param[in] map_type_info - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Key types are restricted to being scalar types. + * + * This is used by WinML to support model reflection APIs. + * + * \param[in] map_type_info + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetMapKeyType, _In_ const OrtMapTypeInfo* map_type_info, _Out_ enum ONNXTensorElementDataType* out); /** \brief Get the value type from an ::OrtMapTypeInfo - * - * \param[in] map_type_info - * \param[out] type_info - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] map_type_info + * \param[out] type_info + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetMapValueType, _In_ const OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** type_info); /// @} @@ -1799,14 +2003,14 @@ struct OrtApi { /// @{ /** \brief Get element type from an ::OrtSequenceTypeInfo - * - * This is used by WinML to support model reflection APIs.
- * - * \param[in] sequence_type_info - * \param[out] type_info - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * This is used by WinML to support model reflection APIs. + * + * \param[in] sequence_type_info + * \param[out] type_info + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetSequenceElementType, _In_ const OrtSequenceTypeInfo* sequence_type_info, _Outptr_ OrtTypeInfo** type_info); @@ -1824,24 +2028,24 @@ struct OrtApi { /// @{ /** \brief End profiling and return filename of the profile data - * - * Profiling is turned on through OrtApi::EnableProfiling - * - * \param[in] session - * \param[in] allocator - * \param[out] out Null terminated string of the filename, allocated using `allocator`. Must be freed using `allocator` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Profiling is turned on through OrtApi::EnableProfiling + * + * \param[in] session + * \param[in] allocator + * \param[out] out Null terminated string of the filename, allocated using `allocator`. Must be freed using `allocator` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionEndProfiling, _In_ OrtSession* session, _Inout_ OrtAllocator* allocator, _Outptr_ char** out); /** \brief Get ::OrtModelMetadata from an ::OrtSession - * - * \param[in] session - * \param[out] out Newly created ::OrtModelMetadata. Must be freed using OrtApi::ReleaseModelMetadata - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] session + * \param[out] out Newly created ::OrtModelMetadata. 
Must be freed using OrtApi::ReleaseModelMetadata + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionGetModelMetadata, _In_ const OrtSession* session, _Outptr_ OrtModelMetadata** out); /// @} @@ -1849,69 +2053,69 @@ struct OrtApi { /// @{ /** \brief Get `producer name` from an ::OrtModelMetadata - * - * \param[in] model_metadata - * \param[in] allocator - * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] model_metadata + * \param[in] allocator + * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(ModelMetadataGetProducerName, _In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); /** \brief Get `graph name` from an ::OrtModelMetadata - * - * \param[in] model_metadata - * \param[in] allocator - * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] model_metadata + * \param[in] allocator + * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(ModelMetadataGetGraphName, _In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); /** \brief Get `domain` from an ::OrtModelMetadata - * - * \param[in] model_metadata - * \param[in] allocator - * \param[out] value Set to a null terminated string allocated using `allocator`. 
Must be freed using `allocator` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] model_metadata + * \param[in] allocator + * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(ModelMetadataGetDomain, _In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); /** \brief Get `description` from an ::OrtModelMetadata - * - * \param[in] model_metadata - * \param[in] allocator - * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] model_metadata + * \param[in] allocator + * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(ModelMetadataGetDescription, _In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); /** \brief Return data for a key in the custom metadata map in an ::OrtModelMetadata - * - * \param[in] model_metadata - * \param[in] allocator - * \param[in] key Null terminated string - * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` - * `value` will be set to nullptr if the given key is not found in the custom metadata map. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] model_metadata + * \param[in] allocator + * \param[in] key Null terminated string + * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` + * `value` will be set to nullptr if the given key is not found in the custom metadata map. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(ModelMetadataLookupCustomMetadataMap, _In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, _In_ const char* key, _Outptr_result_maybenull_ char** value); /** \brief Get version number from an ::OrtModelMetadata - * - * \param[in] model_metadata - * \param[out] value Set to the version number - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] model_metadata + * \param[out] value Set to the version number + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(ModelMetadataGetVersion, _In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value); ORT_CLASS_RELEASE(ModelMetadata); @@ -1921,18 +2125,18 @@ struct OrtApi { /// @{ /** \brief Create an OrtEnv - * - * Create an environment with global threadpools that will be shared across sessions. - * Use this in conjunction with OrtApi::DisablePerSessionThreads or else the session will use - * its own thread pools. - * - * \param[in] log_severity_level The log severity level. - * \param[in] logid The log identifier. - * \param[in] tp_options - * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Create an environment with global threadpools that will be shared across sessions. + * Use this in conjunction with OrtApi::DisablePerSessionThreads or else the session will use + * its own thread pools. + * + * \param[in] log_severity_level The log severity level. + * \param[in] logid The log identifier. + * \param[in] tp_options + * \param[out] out Returned newly created OrtEnv. 
Must be freed with OrtApi::ReleaseEnv + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateEnvWithGlobalThreadPools, OrtLoggingLevel log_severity_level, _In_ const char* logid, _In_ const OrtThreadingOptions* tp_options, _Outptr_ OrtEnv** out); @@ -1941,14 +2145,14 @@ struct OrtApi { /// @{ /** \brief Use global thread pool on a session - * - * Disable using per session thread pool and use the shared global threadpool. - * This should be used in conjunction with OrtApi::CreateEnvWithGlobalThreadPools. - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Disable using per session thread pool and use the shared global threadpool. + * This should be used in conjunction with OrtApi::CreateEnvWithGlobalThreadPools. + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(DisablePerSessionThreads, _Inout_ OrtSessionOptions* options); /// @} @@ -1956,10 +2160,10 @@ struct OrtApi { /// @{ /** \brief Create an ::OrtThreadingOptions - * - * \param[out] out Newly created ::OrtThreadingOptions. Must be freed with OrtApi::ReleaseThreadingOptions - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[out] out Newly created ::OrtThreadingOptions. Must be freed with OrtApi::ReleaseThreadingOptions + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateThreadingOptions, _Outptr_ OrtThreadingOptions** out); ORT_CLASS_RELEASE(ThreadingOptions); @@ -1969,16 +2173,16 @@ struct OrtApi { /// @{ /** - * - * \param[in] model_metadata - * \param[in] allocator - * \param[out] keys Array of null terminated strings (array count = num_keys) allocated using `allocator`. - * The strings and the pointer array must be freed using `allocator` - * `keys` will be set to nullptr if the custom metadata map is empty. 
- * \param[out] num_keys Set to the number of elements in the `keys` array - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] model_metadata + * \param[in] allocator + * \param[out] keys Array of null terminated strings (array count = num_keys) allocated using `allocator`. + * The strings and the pointer array must be freed using `allocator` + * `keys` will be set to nullptr if the custom metadata map is empty. + * \param[out] num_keys Set to the number of elements in the `keys` array + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(ModelMetadataGetCustomMetadataMapKeys, _In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*num_keys) char*** keys, _Out_ int64_t* num_keys); @@ -1987,12 +2191,12 @@ struct OrtApi { /// @{ /** - * - * Override symbolic dimensions (by specific name strings) with actual values - * if known at session initialization time to enable optimizations that can - * take advantage of fixed values (such as memory planning, etc) - * - */ + * + * Override symbolic dimensions (by specific name strings) with actual values + * if known at session initialization time to enable optimizations that can + * take advantage of fixed values (such as memory planning, etc) + * + */ ORT_API2_STATUS(AddFreeDimensionOverrideByName, _Inout_ OrtSessionOptions* options, _In_ const char* dim_name, _In_ int64_t dim_value); @@ -2002,25 +2206,26 @@ struct OrtApi { /// @{ /** \brief Get the names of all available providers - * - * \note The providers in the list are not guaranteed to be usable. They may fail to load due to missing system dependencies. - * For example, if the CUDA/cuDNN libraries are not installed, the CUDA provider will report an error when it is added to the session options. - * - * \param[out] out_ptr Set to a pointer to an array of null terminated strings of the available providers. 
The entries and the - * array itself must be freed using OrtApi::ReleaseAvailableProviders - * \param[out] provider_length Set to the number of entries in the `out_ptr` array - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \note The providers in the list are not guaranteed to be usable. They may fail to load due to missing system dependencies. + * For example, if the CUDA/cuDNN libraries are not installed, the CUDA provider will report an error when it is added to the session options. + * + * \param[out] out_ptr Set to a pointer to an array of null terminated strings of the available providers. The entries and the + * array itself must be freed using OrtApi::ReleaseAvailableProviders + * \param[out] provider_length Set to the number of entries in the `out_ptr` array + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetAvailableProviders, _Outptr_ char*** out_ptr, _Out_ int* provider_length); - /** \brief Release data from OrtApi::GetAvailableProviders - * - * \param[in] ptr The `out_ptr` result from OrtApi::GetAvailableProviders. - * \param[in] providers_length The `provider_length` result from OrtApi::GetAvailableProviders - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + /** \brief Release data from OrtApi::GetAvailableProviders. This API will never fail + * so you can rely on it in a noexcept code. + * + * \param[in] ptr The `out_ptr` result from OrtApi::GetAvailableProviders. 
+ * \param[in] providers_length The `provider_length` result from OrtApi::GetAvailableProviders + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(ReleaseAvailableProviders, _In_ char** ptr, _In_ int providers_length); @@ -2029,34 +2234,34 @@ struct OrtApi { /// @{ /** \brief Get the length of a single string in a string tensor - * - * \param[in] value A string tensor - * \param[in] index Index of the string in the tensor - * \param[out] out Set to number of bytes of the string element - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] value A string tensor + * \param[in] index Index of the string in the tensor + * \param[out] out Set to number of bytes of the string element + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetStringTensorElementLength, _In_ const OrtValue* value, size_t index, _Out_ size_t* out); /** \brief Get a single string from a string tensor - * - * \param[in] value A string tensor - * \param[in] s_len Number of bytes in the `s` buffer. Must match the value returned by OrtApi::GetStringTensorElementLength. - * \param[in] index Index of the string in the tensor - * \param[out] s The string element contents in UTF-8 encoding. The string is NOT null-terminated. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] value A string tensor + * \param[in] s_len Number of bytes in the `s` buffer. Must match the value returned by OrtApi::GetStringTensorElementLength. + * \param[in] index Index of the string in the tensor + * \param[out] s The string element contents in UTF-8 encoding. The string is NOT null-terminated. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetStringTensorElement, _In_ const OrtValue* value, size_t s_len, size_t index, _Out_writes_bytes_all_(s_len) void* s); /** \brief Set a single string in a string tensor - * - * \param[in] value A string tensor - * \param[in] s A null terminated UTF-8 encoded string - * \param[in] index Index of the string in the tensor to set - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] value A string tensor + * \param[in] s A null terminated UTF-8 encoded string + * \param[in] index Index of the string in the tensor to set + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(FillStringTensorElement, _Inout_ OrtValue* value, _In_ const char* s, size_t index); /// @} @@ -2064,17 +2269,17 @@ struct OrtApi { /// @{ /** \brief Set a session configuration entry as a pair of strings - * - * If a configuration with same key exists, this will overwrite the configuration with the given config_value. - * - * The config_key and the format of config_value are defined in onnxruntime_session_options_config_keys.h - * - * \param[in] options - * \param[in] config_key A null terminated string representation of the config key - * \param[in] config_value A null terminated string representation of the config value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * If a configuration with same key exists, this will overwrite the configuration with the given config_value. 
+ * + * The config_key and the format of config_value are defined in onnxruntime_session_options_config_keys.h + * + * \param[in] options + * \param[in] config_key A null terminated string representation of the config key + * \param[in] config_value A null terminated string representation of the config value + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(AddSessionConfigEntry, _Inout_ OrtSessionOptions* options, _In_z_ const char* config_key, _In_z_ const char* config_value); @@ -2083,18 +2288,18 @@ struct OrtApi { /// @{ /** \brief Create an allocator for an ::OrtSession following an ::OrtMemoryInfo - * - * \param[in] session - * \param[in] mem_info valid ::OrtMemoryInfo instance - * \param[out] out Newly created ::OrtAllocator. Must be freed with OrtApi::ReleaseAllocator - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] session + * \param[in] mem_info valid ::OrtMemoryInfo instance + * \param[out] out Newly created ::OrtAllocator. 
Must be freed with OrtApi::ReleaseAllocator + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateAllocator, _In_ const OrtSession* session, _In_ const OrtMemoryInfo* mem_info, _Outptr_ OrtAllocator** out); /** \brief Release an ::OrtAllocator obtained from OrtApi::CreateAllocator - */ + */ ORT_CLASS_RELEASE(Allocator); /// @} @@ -2102,28 +2307,28 @@ struct OrtApi { /// @{ /** \brief Run a model using Io Bindings for the inputs & outputs - * - * \see OrtApi::Run - * - * \param[in] session - * \param[in] run_options - * \param[in] binding_ptr - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \see OrtApi::Run + * + * \param[in] session + * \param[in] run_options + * \param[in] binding_ptr + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(RunWithBinding, _Inout_ OrtSession* session, _In_ const OrtRunOptions* run_options, _In_ const OrtIoBinding* binding_ptr); /** \brief Create an ::OrtIoBinding instance - * - * An IoBinding object allows one to bind pre-allocated ::OrtValue%s to input names. - * Thus if you want to use a raw on device buffer as input or output you can avoid - * extra copy during runtime. - * - * \param[in] session - * \param[out] out Newly created ::OrtIoBinding. Must be freed with OrtApi::ReleaseIoBinding - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * An IoBinding object allows one to bind pre-allocated ::OrtValue%s to input names. + * Thus if you want to use a raw on device buffer as input or output you can avoid + * extra copy during runtime. + * + * \param[in] session + * \param[out] out Newly created ::OrtIoBinding. 
Must be freed with OrtApi::ReleaseIoBinding + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateIoBinding, _Inout_ OrtSession* session, _Outptr_ OrtIoBinding** out); /// @} @@ -2131,87 +2336,87 @@ struct OrtApi { /// @{ /** \brief Release an ::OrtIoBinding obtained from OrtApi::CreateIoBinding - */ + */ ORT_CLASS_RELEASE(IoBinding); /** \brief Bind an ::OrtValue to an ::OrtIoBinding input - * - * When using OrtApi::RunWithBinding this value is used for the named input - * - * \param[in] binding_ptr - * \param[in] name Name for the model input - * \param[in] val_ptr ::OrtValue of Tensor type. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * When using OrtApi::RunWithBinding this value is used for the named input + * + * \param[in] binding_ptr + * \param[in] name Name for the model input + * \param[in] val_ptr ::OrtValue of Tensor type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(BindInput, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtValue* val_ptr); /** \brief Bind an ::OrtValue to an ::OrtIoBinding output - * - * When using OrtApi::RunWithBinding this value is used for the named output - * - * \param[in] binding_ptr - * \param[in] name Null terminated string of the model output name - * \param[in] val_ptr ::OrtValue of Tensor type. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * When using OrtApi::RunWithBinding this value is used for the named output + * + * \param[in] binding_ptr + * \param[in] name Null terminated string of the model output name + * \param[in] val_ptr ::OrtValue of Tensor type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(BindOutput, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtValue* val_ptr); /** \brief Bind an ::OrtIoBinding output to a device - * - * Binds the ::OrtValue to a device which is specified by ::OrtMemoryInfo. 
- * You can either create an instance of ::OrtMemoryInfo with a device id or obtain one from the allocator that you have created/are using - * This is useful when one or more outputs have dynamic shapes and, it is hard to pre-allocate and bind a chunk of - * memory within ::OrtValue ahead of time. - * - * \see OrtApi::RunWithBinding - * - * \param[in] binding_ptr - * \param[in] name Null terminated string of the device name - * \param[in] mem_info_ptr - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Binds the ::OrtValue to a device which is specified by ::OrtMemoryInfo. + * You can either create an instance of ::OrtMemoryInfo with a device id or obtain one from the allocator that you have created/are using + * This is useful when one or more outputs have dynamic shapes and, it is hard to pre-allocate and bind a chunk of + * memory within ::OrtValue ahead of time. + * + * \see OrtApi::RunWithBinding + * + * \param[in] binding_ptr + * \param[in] name Null terminated string of the device name + * \param[in] mem_info_ptr + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(BindOutputToDevice, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtMemoryInfo* mem_info_ptr); /** \brief Get the names of an ::OrtIoBinding's outputs - * - * Returns the names of the outputs in the order they were bound. This is useful after running the model - * with bound outputs because the returned names are in order in which output ::OrtValue are returned. This is useful if - * the order of outputs and their names is not known. - * - * \param[in] binding_ptr - * \param[in] allocator Allocator used to allocate continuous buffers for output strings and lengths. - * \param[out] buffer Returns an array of non-null terminated UTF-8 strings. The number of strings stored is returned in the count parameter. - * This buffer is allocated using `allocator` and must be freed using it. 
- * \param[out] lengths Returns an array of `count` lengths of the strings returned in `buffer` - * This buffer is allocated using `allocator` and must be freed using it. - * \param[out] count Number of strings returned. If `binding_ptr` has no bound outputs, zero is returned, - * no memory allocation is performed and buffer and lengths are set to nullptr. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Returns the names of the outputs in the order they were bound. This is useful after running the model + * with bound outputs because the returned names are in order in which output ::OrtValue are returned. This is useful if + * the order of outputs and their names is not known. + * + * \param[in] binding_ptr + * \param[in] allocator Allocator used to allocate continuous buffers for output strings and lengths. + * \param[out] buffer Returns an array of non-null terminated UTF-8 strings. The number of strings stored is returned in the count parameter. + * This buffer is allocated using `allocator` and must be freed using it. + * \param[out] lengths Returns an array of `count` lengths of the strings returned in `buffer` + * This buffer is allocated using `allocator` and must be freed using it. + * \param[out] count Number of strings returned. If `binding_ptr` has no bound outputs, zero is returned, + * no memory allocation is performed and buffer and lengths are set to nullptr. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetBoundOutputNames, _In_ const OrtIoBinding* binding_ptr, _In_ OrtAllocator* allocator, _Out_ char** buffer, _Out_writes_all_(count) size_t** lengths, _Out_ size_t* count); /** \brief Get the output ::OrtValue objects from an ::OrtIoBinding - * - * Returns an array of pointers to individually allocated ::OrtValue%s that contain results of a model execution with OrtApi::RunWithBinding - * The array contains the same number of ::OrtValue%s and they are in the same order as they were bound with OrtApi::BindOutput - * or OrtApi::BindOutputToDevice. - * - * The returned ::OrtValue%s must be released using OrtApi::ReleaseValue after they are no longer needed. - * The array is allocated using the specified instance of the allocator and must be freed using the same allocator after - * all the ::OrtValue%s contained therein are individually released. - * - * \param[in] binding_ptr - * \param[in] allocator Allocator used to allocate output array - * \param[out] output Set to the allocated array of allocated ::OrtValue outputs. Set to nullptr if there are 0 outputs. - * \param[out] output_count Set to number of ::OrtValue%s returned - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Returns an array of pointers to individually allocated ::OrtValue%s that contain results of a model execution with OrtApi::RunWithBinding + * The array contains the same number of ::OrtValue%s and they are in the same order as they were bound with OrtApi::BindOutput + * or OrtApi::BindOutputToDevice. + * + * The returned ::OrtValue%s must be released using OrtApi::ReleaseValue after they are no longer needed. + * The array is allocated using the specified instance of the allocator and must be freed using the same allocator after + * all the ::OrtValue%s contained therein are individually released. 
+ * + * \param[in] binding_ptr + * \param[in] allocator Allocator used to allocate output array + * \param[out] output Set to the allocated array of allocated ::OrtValue outputs. Set to nullptr if there are 0 outputs. + * \param[out] output_count Set to number of ::OrtValue%s returned + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetBoundOutputValues, _In_ const OrtIoBinding* binding_ptr, _In_ OrtAllocator* allocator, _Out_writes_all_(output_count) OrtValue*** output, _Out_ size_t* output_count); @@ -2228,19 +2433,19 @@ struct OrtApi { /// @{ /** \brief Direct memory access to a specified tensor element - * - * For example, given a tensor with shape of [3,224,224], a pointer to the element at location [2,150,128] can be retrieved - * - * This function only works for numeric type tensors (No strings, etc). - * This is a no-copy method whose returned pointer is valid until the passed in ::OrtValue is free'd. - * - * \param[in] value - * \param[in] location_values Pointer to an array of index values that specify an element's location relative to its shape - * \param[in] location_values_count Number of elements in location_values. Must match the number of elements in the tensor's shape. - * \param[out] out Set to a pointer to the element specified - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * For example, given a tensor with shape of [3,224,224], a pointer to the element at location [2,150,128] can be retrieved + * + * This function only works for numeric type tensors (No strings, etc). + * This is a no-copy method whose returned pointer is valid until the passed in ::OrtValue is free'd. + * + * \param[in] value + * \param[in] location_values Pointer to an array of index values that specify an element's location relative to its shape + * \param[in] location_values_count Number of elements in location_values. Must match the number of elements in the tensor's shape. 
+ * \param[out] out Set to a pointer to the element specified + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count, _Outptr_ void** out); /// @} @@ -2248,33 +2453,33 @@ struct OrtApi { /// @{ /** \brief Create an allocator and register it with the ::OrtEnv - * - * Enables sharing the allocator between multiple sessions that use the same env instance. - * Lifetime of the created allocator will be valid for the duration of the environment. - * Returns an error if an allocator with the same ::OrtMemoryInfo is already registered. - * - * See https://onnxruntime.ai/docs/reference/api/c-api.html for details. - * - * \param[in] env ::OrtEnv instance - * \param[in] mem_info - * \param[in] arena_cfg Pass nullptr for defaults - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Enables sharing the allocator between multiple sessions that use the same env instance. + * Lifetime of the created allocator will be valid for the duration of the environment. + * Returns an error if an allocator with the same ::OrtMemoryInfo is already registered. + * + * See https://onnxruntime.ai/docs/get-started/with-c.html for details. + * + * \param[in] env ::OrtEnv instance + * \param[in] mem_info + * \param[in] arena_cfg Pass nullptr for defaults + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg); /** \brief Set language projection - * - * Set the language projection for collecting telemetry data when Env is created. - * - * The default is ORT_PROJECTION_C, which means it will classify the language not in the list to C also. 
- * - * \param[in] ort_env - * \param[in] projection - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Set the language projection for collecting telemetry data when Env is created. + * + * The default is ORT_PROJECTION_C, which means it will classify the language not in the list to C also. + * + * \param[in] ort_env + * \param[in] projection + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetLanguageProjection, _In_ const OrtEnv* ort_env, _In_ OrtLanguageProjection projection); /// @} @@ -2282,14 +2487,14 @@ struct OrtApi { /// @{ /** \brief Return the time that profiling was started - * - * \note The timer precision varies per platform. On Windows and MacOS, the precision will be ~100ns - * - * \param[in] session - * \param[out] out nanoseconds of profiling's start time - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \note The timer precision varies per platform. On Windows and MacOS, the precision will be ~100ns + * + * \param[in] session + * \param[out] out nanoseconds of profiling's start time + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionGetProfilingStartTimeNs, _In_ const OrtSession* session, _Outptr_ uint64_t* out); /// @} @@ -2297,44 +2502,44 @@ struct OrtApi { /// @{ /** \brief Set global intra-op thread count - * - * This configures the global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools - * - * \param[in] tp_options - * \param[in] intra_op_num_threads Number of threads, special values:
- * 0 = Use default thread count
- * 1 = The invoking thread will be used; no threads will be created in the thread pool. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * This configures the global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools + * + * \param[in] tp_options + * \param[in] intra_op_num_threads Number of threads, special values:
+ * 0 = Use default thread count
+ * 1 = The invoking thread will be used; no threads will be created in the thread pool. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetGlobalIntraOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int intra_op_num_threads); /** \brief Set global inter-op thread count - * - * This configures the global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools - * - * \param[in] tp_options - * \param[in] inter_op_num_threads Number of threads, special values:
- * 0 = Use default thread count
- * 1 = The invoking thread will be used; no threads will be created in the thread pool. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * This configures the global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools + * + * \param[in] tp_options + * \param[in] inter_op_num_threads Number of threads, special values:
+ * 0 = Use default thread count
+ * 1 = The invoking thread will be used; no threads will be created in the thread pool. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetGlobalInterOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int inter_op_num_threads); /** \brief Set global spin control options - * - * This will configure the global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools. - * Allow spinning of thread pools when their queues are empty. This will set the value for both - * inter_op and intra_op threadpools. - * - * \param[in] tp_options - * \param[in] allow_spinning Valid values are 0 or 1.
- * 0 = It won't spin (recommended if CPU usage is high)
- * 1 = Threadpool will spin to wait for queue to become non-empty - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * This will configure the global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools. + * Allow spinning of thread pools when their queues are empty. This will set the value for both + * inter_op and intra_op threadpools. + * + * \param[in] tp_options + * \param[in] allow_spinning Valid values are 0 or 1.
+ * 0 = It won't spin (recommended if CPU usage is high)
+ * 1 = Threadpool will spin to wait for queue to become non-empty + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetGlobalSpinControl, _Inout_ OrtThreadingOptions* tp_options, int allow_spinning); /// @} @@ -2342,19 +2547,19 @@ struct OrtApi { /// @{ /** \brief Add a pre-allocated initializer to a session - * - * If a model contains an initializer with a name that is same as the name passed to this call, - * ORT will use this initializer instance instead of deserializing one from the model file. This - * is useful when you want to share the same initializer across sessions. - * - * \param[in] options - * \param[in] name Null terminated string of the initializer name - * \param[in] val ::OrtValue containing the initializer. Its lifetime and the underlying initializer buffer must be - * managed by the user (created using the OrtApi::CreateTensorWithDataAsOrtValue) and it must outlive the session object - * to which it is added. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * If a model contains an initializer with a name that is same as the name passed to this call, + * ORT will use this initializer instance instead of deserializing one from the model file. This + * is useful when you want to share the same initializer across sessions. + * + * \param[in] options + * \param[in] name Null terminated string of the initializer name + * \param[in] val ::OrtValue containing the initializer. Its lifetime and the underlying initializer buffer must be + * managed by the user (created using the OrtApi::CreateTensorWithDataAsOrtValue) and it must outlive the session object + * to which it is added. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(AddInitializer, _Inout_ OrtSessionOptions* options, _In_z_ const char* name, _In_ const OrtValue* val); @@ -2363,20 +2568,20 @@ struct OrtApi { /// @{ /** - * Create a custom environment with global threadpools and logger that will be shared across sessions. - * Use this in conjunction with OrtApi::DisablePerSessionThreads or else the session will use - * its own thread pools. - * - * \param[in] logging_function A pointer to a logging function. - * \param[in] logger_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to - * `logging_function`. - * \param[in] log_severity_level The log severity level. - * \param[in] logid The log identifier. - * \param[in] tp_options - * \param[out] out Newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * Create a custom environment with global threadpools and logger that will be shared across sessions. + * Use this in conjunction with OrtApi::DisablePerSessionThreads or else the session will use + * its own thread pools. + * + * \param[in] logging_function A pointer to a logging function. + * \param[in] logger_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to + * `logging_function`. + * \param[in] log_severity_level The log severity level. + * \param[in] logid The log identifier. + * \param[in] tp_options + * \param[out] out Newly created OrtEnv. 
Must be freed with OrtApi::ReleaseEnv + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateEnvWithCustomLoggerAndGlobalThreadPools, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, OrtLoggingLevel log_severity_level, _In_ const char* logid, _In_ const struct OrtThreadingOptions* tp_options, _Outptr_ OrtEnv** out); @@ -2385,38 +2590,38 @@ struct OrtApi { /// @{ /** \brief Append CUDA provider to session options - * - * If CUDA is not available (due to a non CUDA enabled build, or if CUDA is not installed on the system), this function will return failure. - * - * \param[in] options - * \param[in] cuda_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * If CUDA is not available (due to a non CUDA enabled build, or if CUDA is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] cuda_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptions* cuda_options); /** \brief Append ROCM execution provider to the session options - * - * If ROCM is not available (due to a non ROCM enabled build, or if ROCM is not installed on the system), this function will return failure. - * - * \param[in] options - * \param[in] rocm_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * If ROCM is not available (due to a non ROCM enabled build, or if ROCM is not installed on the system), this function will return failure. 
+ * + * \param[in] options + * \param[in] rocm_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_ROCM, _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options); /** \brief Append OpenVINO execution provider to the session options - * - * If OpenVINO is not available (due to a non OpenVINO enabled build, or if OpenVINO is not installed on the system), this function will fail. - * - * \param[in] options - * \param[in] provider_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * If OpenVINO is not available (due to a non OpenVINO enabled build, or if OpenVINO is not installed on the system), this function will fail. + * + * \param[in] options + * \param[in] provider_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO, _In_ OrtSessionOptions* options, _In_ const OrtOpenVINOProviderOptions* provider_options); @@ -2425,15 +2630,15 @@ struct OrtApi { /// @{ /** \brief Set threading flush-to-zero and denormal-as-zero - * - * Sets global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools. - * Flush-to-zero and denormal-as-zero are applied to threads in both intra and inter global thread pool. - * \note This option is not needed if the models used have no denormals. Having no denormals is recommended as this option may hurt model accuracy. - * - * \param[in] tp_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Sets global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools. + * Flush-to-zero and denormal-as-zero are applied to threads in both intra and inter global thread pool. + * \note This option is not needed if the models used have no denormals. Having no denormals is recommended as this option may hurt model accuracy. 
+ * + * \param[in] tp_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetGlobalDenormalAsZero, _Inout_ OrtThreadingOptions* tp_options); /// @} @@ -2441,17 +2646,17 @@ struct OrtApi { /// @{ /** \deprecated Use OrtApi::CreateArenaCfgV2 - * - * This will create the configuration of an arena that can eventually be used to define an arena based allocator's behavior - * - * \param[in] max_mem Use 0 to allow ORT to choose the default - * \param[in] arena_extend_strategy Use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested - * \param[in] initial_chunk_size_bytes Use -1 to allow ORT to choose the default - * \param[in] max_dead_bytes_per_chunk Use -1 to allow ORT to choose the default - * \param[in] out A pointer to an OrtArenaCfg instance - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * This will create the configuration of an arena that can eventually be used to define an arena based allocator's behavior + * + * \param[in] max_mem Use 0 to allow ORT to choose the default + * \param[in] arena_extend_strategy Use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested + * \param[in] initial_chunk_size_bytes Use -1 to allow ORT to choose the default + * \param[in] max_dead_bytes_per_chunk Use -1 to allow ORT to choose the default + * \param[in] out A pointer to an OrtArenaCfg instance + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateArenaCfg, _In_ size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk, _Outptr_ OrtArenaCfg** out); @@ -2462,16 +2667,16 @@ struct OrtApi { /// @{ /** - * Use this to obtain the description of the graph present in the model - * (doc_string field of the GraphProto message within the ModelProto message). - * If it doesn't exist, an empty string will be returned. 
- * - * \param[in] model_metadata An instance of ::OrtModelMetadata - * \param[in] allocator Allocator used to allocate the string that will be returned back - * \param[out] value Set to a null terminated string allocated using `allocator`. The caller is responsible for freeing it using `allocator` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * Use this to obtain the description of the graph present in the model + * (doc_string field of the GraphProto message within the ModelProto message). + * If it doesn't exist, an empty string will be returned. + * + * \param[in] model_metadata An instance of ::OrtModelMetadata + * \param[in] allocator Allocator used to allocate the string that will be returned back + * \param[out] value Set to a null terminated string allocated using `allocator`. The caller is responsible for freeing it using `allocator` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(ModelMetadataGetGraphDescription, _In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); @@ -2480,14 +2685,14 @@ struct OrtApi { /// @{ /** \brief Append TensorRT provider to session options - * - * If TensorRT is not available (due to a non TensorRT enabled build, or if TensorRT is not installed on the system), this function will return failure. - * - * \param[in] options - * \param[in] tensorrt_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * If TensorRT is not available (due to a non TensorRT enabled build, or if TensorRT is not installed on the system), this function will return failure. 
+ * + * \param[in] options + * \param[in] tensorrt_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_TensorRT, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options); @@ -2496,113 +2701,118 @@ struct OrtApi { /// @{ /** \brief Set current GPU device ID - * - * Set the current device id of the GPU execution provider (CUDA/tensorrt/rocm). The device id should be less - * than the total number of devices available. This is only useful when multiple-GPUs are installed and it is - * required to restrict execution to a single GPU. - * - * \param[in] device_id - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Set the current device id of the GPU execution provider (CUDA/tensorrt/rocm). The device id should be less + * than the total number of devices available. This is only useful when multiple-GPUs are installed and it is + * required to restrict execution to a single GPU. + * + * \param[in] device_id + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetCurrentGpuDeviceId, _In_ int device_id); /** \brief Get current GPU device ID - * - * Get the current device id of the GPU execution provider (CUDA/tensorrt/rocm). - * - * \see OrtApi::SetCurrentGpuDeviceId - * - * \param[out] device_id - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Get the current device id of the GPU execution provider (CUDA/tensorrt/rocm). + * + * \see OrtApi::SetCurrentGpuDeviceId + * + * \param[out] device_id + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetCurrentGpuDeviceId, _In_ int* device_id); /// @} /// \name OrtKernelInfo + /// Custom operator APIs. 
/// @{ /** \brief Fetch an array of int64_t values stored as an attribute in the graph node - * - * - * If `out` is nullptr, the value of `size` is set to the true size of the attribute - * array's size, and a success status is returned. - * - * If the `size` parameter is greater than or equal to the actual attribute array's size, - * the value of `size` is set to the true size of the attribute array's size, - * the provided memory is filled with the attribute's contents, - * and a success status is returned. - * - * If the `size` parameter is less than the actual attribute array's size and `out` - * is not nullptr, the value of `size` is set to the true size of the attribute array's size - * and a failure status is returned.) - * - * \param[in] info instance - * \param[in] name name of the attribute to be parsed - * \param[out] out pointer to memory where the attribute's contents are to be stored - * \param[in, out] size actual size of attribute array - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * + * If `out` is nullptr, the value of `size` is set to the true size of the attribute + * array's size, and a success status is returned. + * + * If the `size` parameter is greater than or equal to the actual attribute array's size, + * the value of `size` is set to the true size of the attribute array's size, + * the provided memory is filled with the attribute's contents, + * and a success status is returned. + * + * If the `size` parameter is less than the actual attribute array's size and `out` + * is not nullptr, the value of `size` is set to the true size of the attribute array's size + * and a failure status is returned.) 
+ * + * \param[in] info instance + * \param[in] name name of the attribute to be parsed + * \param[out] out pointer to memory where the attribute's contents are to be stored + * \param[in, out] size actual size of attribute array + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out, _Inout_ size_t* size); /** \brief Fetch an array of int64_t values stored as an attribute in the graph node - * - * If `out` is nullptr, the value of `size` is set to the true size of the attribute - * array's size, and a success status is returned. - * - * If the `size` parameter is greater than or equal to the actual attribute array's size, - * the value of `size` is set to the true size of the attribute array's size, - * the provided memory is filled with the attribute's contents, - * and a success status is returned. - * - * If the `size` parameter is less than the actual attribute array's size and `out` - * is not nullptr, the value of `size` is set to the true size of the attribute array's size - * and a failure status is returned.) - * - * \param[in] info instance - * \param[in] name name of the attribute to be parsed - * \param[out] out pointer to memory where the attribute's contents are to be stored - * \param[in, out] size actual size of attribute array - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, - _Out_ int64_t* out, _Inout_ size_t* size); - - /// @} - /// \name OrtArenaCfg + * + * If `out` is nullptr, the value of `size` is set to the true size of the attribute + * array's size, and a success status is returned. 
+ * + * If the `size` parameter is greater than or equal to the actual attribute array's size, + * the value of `size` is set to the true size of the attribute array's size, + * the provided memory is filled with the attribute's contents, + * and a success status is returned. + * + * If the `size` parameter is less than the actual attribute array's size and `out` + * is not nullptr, the value of `size` is set to the true size of the attribute array's size + * and a failure status is returned.) + * + * \param[in] info instance + * \param[in] name name of the attribute to be parsed + * \param[out] out pointer to memory where the attribute's contents are to be stored + * \param[in, out] size actual size of attribute array + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ int64_t* out, _Inout_ size_t* size); + + /// @} + /// \name OrtArenaCfg /// @{ /** \brief Create an ::OrtArenaCfg - * - * Create the configuration of an arena that can eventually be used to define an arena based allocator's behavior. - * - * Supported keys are (See https://onnxruntime.ai/docs/reference/api/c-api.html for details on what the - * following parameters mean and how to choose these values.): - * "max_mem": Maximum memory that can be allocated by the arena based allocator. - * Use 0 for ORT to pick the best value. Default is 0. - * "arena_extend_strategy": 0 = kNextPowerOfTwo, 1 = kSameAsRequested. - * Use -1 to allow ORT to choose the default. - * "initial_chunk_size_bytes": (Possible) Size of the first allocation in the arena. - * Only relevant if arena strategy is `kNextPowerOfTwo`. Use -1 to allow ORT to choose the default. - * Ultimately, the first allocation size is determined by the allocation memory request. 
- * "max_dead_bytes_per_chunk": Threshold of unused memory in an allocated chunk of arena memory after - * crossing which the current chunk is chunked into 2. - * "initial_growth_chunk_size_bytes": (Possible) Size of the second allocation in the arena. - * Only relevant if arena strategy is `kNextPowerOfTwo`. Use -1 to allow ORT to choose the default. - * Ultimately, the allocation size is determined by the allocation memory request. - * Further allocation sizes are governed by the arena extend strategy. - * - * \param[in] arena_config_keys Keys to configure the arena - * \param[in] arena_config_values Values to configure the arena - * \param[in] num_keys Number of keys in `arena_config_keys` and `arena_config_values` - * \param[out] out Newly created ::OrtArenaCfg. Must be freed with OrtApi::ReleaseArenaCfg - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Create the configuration of an arena that can eventually be used to define an arena based allocator's behavior. + * + * Supported keys are (See https://onnxruntime.ai/docs/get-started/with-c.html for details on what the + * following parameters mean and how to choose these values.): + * "max_mem": Maximum memory that can be allocated by the arena based allocator. + * Use 0 for ORT to pick the best value. Default is 0. + * "arena_extend_strategy": 0 = kNextPowerOfTwo, 1 = kSameAsRequested. + * Use -1 to allow ORT to choose the default. + * "initial_chunk_size_bytes": (Possible) Size of the first allocation in the arena. + * Only relevant if arena strategy is `kNextPowerOfTwo`. Use -1 to allow ORT to choose the default. + * Ultimately, the first allocation size is determined by the allocation memory request. + * "max_dead_bytes_per_chunk": Threshold of unused memory in an allocated chunk of arena memory after + * crossing which the current chunk is chunked into 2. + * "initial_growth_chunk_size_bytes": (Possible) Size of the second allocation in the arena. 
+ * Only relevant if arena strategy is `kNextPowerOfTwo`. Use -1 to allow ORT to choose the default. + * "max_power_of_two_extend_bytes": The maximum extend size if arena strategy is `kNextPowerOfTwo`. + * It is not an allocation limit, it is only a limit for extension when the requested bytes are less than the limit. + * When the requested bytes are more than the limit, the allocator will still return as requested. + * Use -1 to allow ORT to choose the default 1GB for max_power_of_two_extend_bytes. + * Ultimately, the allocation size is determined by the allocation memory request. + * Further allocation sizes are governed by the arena extend strategy. + * + * \param[in] arena_config_keys Keys to configure the arena + * \param[in] arena_config_values Values to configure the arena + * \param[in] num_keys Number of keys in `arena_config_keys` and `arena_config_values` + * \param[out] out Newly created ::OrtArenaCfg. Must be freed with OrtApi::ReleaseArenaCfg + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateArenaCfgV2, _In_reads_(num_keys) const char* const* arena_config_keys, _In_reads_(num_keys) const size_t* arena_config_values, _In_ size_t num_keys, _Outptr_ OrtArenaCfg** out); @@ -2612,17 +2822,17 @@ struct OrtApi { /// @{ /** \brief Set a single run configuration entry as a pair of strings - * - * If a configuration with same key exists, this will overwrite the configuration with the given config_value - * - * The config_key and the format of config_value are defined in onnxruntime_run_options_config_keys.h - * - * \param[in] options - * \param[in] config_key A null terminated string representation of the config key - * \param[in] config_value A null terminated string representation of the config value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * If a configuration with same key exists, this will overwrite the configuration with the given config_value + * + * The config_key and the format of config_value are
defined in onnxruntime_run_options_config_keys.h + * + * \param[in] options + * \param[in] config_key A null terminated string representation of the config key + * \param[in] config_value A null terminated string representation of the config value + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(AddRunConfigEntry, _Inout_ OrtRunOptions* options, _In_z_ const char* config_key, _In_z_ const char* config_value); @@ -2631,23 +2841,23 @@ struct OrtApi { /// @{ /** \brief Create an ::OrtPrepackedWeightsContainer - * - * This container will hold pre-packed buffers of shared initializers for sharing between sessions - * (i.e.) if there are shared initializers that can be shared between sessions, the pre-packed buffers - * of these (if any) may possibly be shared to provide memory footprint savings. Pass this container - * to sessions that you would like to share pre-packed buffers of shared initializers at session - * creation time. - * - * \param[out] out Newly created ::OrtPrepackedWeightsContainer. Must be freed with OrtApi::ReleasePrepackedWeightsContainer - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * This container will hold pre-packed buffers of shared initializers for sharing between sessions + * (i.e.) if there are shared initializers that can be shared between sessions, the pre-packed buffers + * of these (if any) may possibly be shared to provide memory footprint savings. Pass this container + * to sessions that you would like to share pre-packed buffers of shared initializers at session + * creation time. + * + * \param[out] out Newly created ::OrtPrepackedWeightsContainer. 
Must be freed with OrtApi::ReleasePrepackedWeightsContainer + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreatePrepackedWeightsContainer, _Outptr_ OrtPrepackedWeightsContainer** out); /** \brief Release OrtPrepackedWeightsContainer instance - * - * \note instance must not be released until the sessions using it are released - */ + * + * \note instance must not be released until the sessions using it are released + */ ORT_CLASS_RELEASE(PrepackedWeightsContainer); /// @} @@ -2655,44 +2865,44 @@ struct OrtApi { /// @{ /** \brief Create session with prepacked weights container - * - * Same functionality offered by OrtApi::CreateSession except that a container that contains - * pre-packed weights' buffers is written into/read from by the created session. - * This is useful when used in conjunction with OrtApi::AddInitializer which injects - * shared initializer info into sessions. Wherever possible, the pre-packed versions of these - * shared initializers are cached in this container so that multiple sessions can just re-use - * these instead of duplicating these in memory. - * - * \param[in] env OrtEnv instance instance - * \param[in] model_path Null terminated string of the path (wchar on Windows, char otherwise) - * \param[in] options - * \param[in] prepacked_weights_container - * \param[out] out Newly created ::OrtSession. Must be freed with OrtApi::ReleaseSession - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Same functionality offered by OrtApi::CreateSession except that a container that contains + * pre-packed weights' buffers is written into/read from by the created session. + * This is useful when used in conjunction with OrtApi::AddInitializer which injects + * shared initializer info into sessions. Wherever possible, the pre-packed versions of these + * shared initializers are cached in this container so that multiple sessions can just re-use + * these instead of duplicating these in memory. 
+ * + * \param[in] env OrtEnv instance + * \param[in] model_path Null terminated string of the path (wchar on Windows, char otherwise) + * \param[in] options + * \param[in] prepacked_weights_container + * \param[out] out Newly created ::OrtSession. Must be freed with OrtApi::ReleaseSession + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateSessionWithPrepackedWeightsContainer, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, _In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, _Outptr_ OrtSession** out); /** \brief Create session from memory with prepacked weights container - * - * Same functionality offered by OrtApi::CreateSessionFromArray except that a container that contains - * pre-packed weights' buffers is written into/read from by the created session. - * This is useful when used in conjunction with OrtApi::AddInitializer which injects - * shared initializer info into sessions. Wherever possible, the pre-packed versions of these - * shared initializers are cached in this container so that multiple sessions can just re-use - * these instead of duplicating these in memory. - * - * \param[in] env - * \param[in] model_data Array of bytes holding the model - * \param[in] model_data_length Number of bytes in `model_data_model` - * \param[in] options - * \param[in] prepacked_weights_container - * \param[out] out Newly created ::OrtSession. Must be freed with OrtApi::ReleaseSession - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Same functionality offered by OrtApi::CreateSessionFromArray except that a container that contains + * pre-packed weights' buffers is written into/read from by the created session. + * This is useful when used in conjunction with OrtApi::AddInitializer which injects + * shared initializer info into sessions.
Wherever possible, the pre-packed versions of these + * shared initializers are cached in this container so that multiple sessions can just re-use + * these instead of duplicating these in memory. + * + * \param[in] env + * \param[in] model_data Array of bytes holding the model + * \param[in] model_data_length Number of bytes in `model_data` + * \param[in] options + * \param[in] prepacked_weights_container + * \param[out] out Newly created ::OrtSession. Must be freed with OrtApi::ReleaseSession + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateSessionFromArrayWithPrepackedWeightsContainer, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, _In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, @@ -2703,22 +2913,22 @@ struct OrtApi { /// @{ /** \brief Append TensorRT execution provider to the session options - * - * If TensorRT is not available (due to a non TensorRT enabled build), this function will return failure. - * - * This is slightly different from OrtApi::SessionOptionsAppendExecutionProvider_TensorRT, it takes an - * ::OrtTensorRTProviderOptions which is publicly defined. This takes an opaque ::OrtTensorRTProviderOptionsV2 - * which must be created with OrtApi::CreateTensorRTProviderOptions. - * - * For OrtApi::SessionOptionsAppendExecutionProvider_TensorRT, the user needs to instantiate ::OrtTensorRTProviderOptions - * as well as allocate/release buffers for some members of ::OrtTensorRTProviderOptions. - * Here, OrtApi::CreateTensorRTProviderOptions and Ortapi::ReleaseTensorRTProviderOptions will do the memory management for you. - * - * \param[in] options - * \param[in] tensorrt_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * If TensorRT is not available (due to a non TensorRT enabled build), this function will return failure.
+ * + * This is slightly different from OrtApi::SessionOptionsAppendExecutionProvider_TensorRT, which takes an + * ::OrtTensorRTProviderOptions which is publicly defined. This takes an opaque ::OrtTensorRTProviderOptionsV2 + * which must be created with OrtApi::CreateTensorRTProviderOptions. + * + * For OrtApi::SessionOptionsAppendExecutionProvider_TensorRT, the user needs to instantiate ::OrtTensorRTProviderOptions + * as well as allocate/release buffers for some members of ::OrtTensorRTProviderOptions. + * Here, OrtApi::CreateTensorRTProviderOptions and OrtApi::ReleaseTensorRTProviderOptions will do the memory management for you. + * + * \param[in] options + * \param[in] tensorrt_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_TensorRT_V2, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options); @@ -2727,50 +2937,50 @@ struct OrtApi { /// @{ /** \brief Create an OrtTensorRTProviderOptionsV2 - * - * \param[out] out Newly created ::OrtTensorRTProviderOptionsV2. Must be released with OrtApi::ReleaseTensorRTProviderOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[out] out Newly created ::OrtTensorRTProviderOptionsV2. Must be released with OrtApi::ReleaseTensorRTProviderOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateTensorRTProviderOptions, _Outptr_ OrtTensorRTProviderOptionsV2** out); /** \brief Set options in a TensorRT Execution Provider. - * - * Please refer to https://www.onnxruntime.ai/docs/reference/execution-providers/TensorRT-ExecutionProvider.html#c-api-example - * to know the available keys and values. Key should be in null terminated string format of the member of ::OrtTensorRTProviderOptionsV2 - * and value should be its related range.
- * - * For example, key="trt_max_workspace_size" and value="2147483648" - * - * \param[in] tensorrt_options - * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys - * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values - * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Please refer to https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#cc + * to know the available keys and values. Key should be in null terminated string format of the member of ::OrtTensorRTProviderOptionsV2 + * and value should be its related range. Recreates the options and only sets the supplied values. + * + * For example, key="trt_max_workspace_size" and value="2147483648" + * + * \param[in] tensorrt_options + * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys + * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values + * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(UpdateTensorRTProviderOptions, _Inout_ OrtTensorRTProviderOptionsV2* tensorrt_options, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); /** \brief Get serialized TensorRT provider options string. - * - * For example, "trt_max_workspace_size=2147483648;trt_max_partition_iterations=10;trt_int8_enable=1;......" 
- * - * \param tensorrt_options - OrTensorRTProviderOptionsV2 instance - * \param allocator - a ptr to an instance of OrtAllocator obtained with OrtApi::CreateAllocator or OrtApi::GetAllocatorWithDefaultOptions - * the specified allocator will be used to allocate continuous buffers for output strings and lengths. - * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * For example, "trt_max_workspace_size=2147483648;trt_max_partition_iterations=10;trt_int8_enable=1;......" + * + * \param tensorrt_options - OrtTensorRTProviderOptionsV2 instance + * \param allocator - a ptr to an instance of OrtAllocator obtained with OrtApi::CreateAllocator or OrtApi::GetAllocatorWithDefaultOptions + * the specified allocator will be used to allocate continuous buffers for output strings and lengths. + * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetTensorRTProviderOptionsAsString, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); /** \brief Release an ::OrtTensorRTProviderOptionsV2 - * - * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does - */ + * + * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does + */ void(ORT_API_CALL* ReleaseTensorRTProviderOptions)(_Frees_ptr_opt_ OrtTensorRTProviderOptionsV2* input); /// @} @@ -2778,11 +2988,11 @@ struct OrtApi { /// @{ /** \brief Enable custom operators - * - * See onnxruntime-extensions: https://github.com/microsoft/onnxruntime-extensions.git - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * See onnxruntime-extensions: https://github.com/microsoft/onnxruntime-extensions.git + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(EnableOrtCustomOps, _Inout_ OrtSessionOptions* options); /// @} @@ -2790,32 +3000,32 @@ struct OrtApi { /// @{ /** \brief Register a custom allocator - * - * Enables sharing between multiple sessions that use the same env instance. - * Returns an error if an allocator with the same ::OrtMemoryInfo is already registered. - * - * The behavior of this is exactly the same as OrtApi::CreateAndRegisterAllocator except - * instead of ORT creating an allocator based on provided info, in this case - * ORT uses the user-provided custom allocator. - * See https://onnxruntime.ai/docs/reference/api/c-api.html for details. - * - * \param[in] env - * \param[in] allocator User provided allocator - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Enables sharing between multiple sessions that use the same env instance. 
+ * Returns an error if an allocator with the same ::OrtMemoryInfo is already registered. + * + * The behavior of this is exactly the same as OrtApi::CreateAndRegisterAllocator except + * instead of ORT creating an allocator based on provided info, in this case + * ORT uses the user-provided custom allocator. + * See https://onnxruntime.ai/docs/get-started/with-c.html for details. + * + * \param[in] env + * \param[in] allocator User provided allocator + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(RegisterAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocator* allocator); /** \brief Unregister a custom allocator - * - * It is an error if you provide an ::OrtMemoryInfo not corresponding to any - * registered allocators for sharing. - * - * \param[in] env - * \param[in] mem_info - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * It is an error if you provide an ::OrtMemoryInfo not corresponding to any + * registered allocators for sharing. + * + * \param[in] env + * \param[in] mem_info + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(UnregisterAllocator, _Inout_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info); @@ -2824,235 +3034,238 @@ struct OrtApi { /// @{ /** \brief Sets *out to 1 iff an ::OrtValue is a SparseTensor, and 0 otherwise - * - * \param[in] value existing ::OrtValue - * \param[out] out unless an error occurs, contains 1 iff the value contains an instance - * of sparse tensor or 0 otherwise. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] value existing ::OrtValue + * \param[out] out unless an error occurs, contains 1 iff the value contains an instance + * of sparse tensor or 0 otherwise. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(IsSparseTensor, _In_ const OrtValue* value, _Out_ int* out); /** \brief Create an ::OrtValue with a sparse tensor that is empty. 
- * - * Use FillSparseTensor() functions to populate sparse tensor with non-zero values and - * format specific indices data. - * Use ReleaseValue to destroy the sparse tensor, this will also release the buffer inside the output value - * if any was allocated. - * \param[in,out] allocator allocator to use when performing an allocation. Allocation will be performed - * by FillSparseTensor() APIs. The lifespan of the allocator instance must eclipse the lifespan - * this sparse tensor instance as the same allocator will be used to free memory. - * \param[in] dense_shape shape of the original dense tensor - * \param[in] dense_shape_len number of shape dimensions being passed - * \param[in] type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx - * \param[out] out Should be freed by calling ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * Use FillSparseTensor() functions to populate sparse tensor with non-zero values and + * format specific indices data. + * Use ReleaseValue to destroy the sparse tensor, this will also release the buffer inside the output value + * if any was allocated. + * \param[in,out] allocator allocator to use when performing an allocation. Allocation will be performed + * by FillSparseTensor() APIs. The lifespan of the allocator instance must eclipse the lifespan + * this sparse tensor instance as the same allocator will be used to free memory. 
+ * \param[in] dense_shape shape of the original dense tensor + * \param[in] dense_shape_len number of shape dimensions being passed + * \param[in] type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx + * \param[out] out Should be freed by calling ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateSparseTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* dense_shape, size_t dense_shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out); /** - * This fills populates an empty tensor that was created using OrtApi::CreateSparseTensorAsOrtValue. - * This will allocate required memory and copy the supplied NNZ values and COO indices into that memory allocation. - * Memory allocation is performed using the allocator that was specified with OrtApi::CreateSparseTensorAsOrtValue. - * - * \param[in,out] ort_value ::OrtValue to populate with data - * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified - * at the creation time has memory info that is not the same as mem_info argument to this function a X-device copy will be performed. - * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer. - * \param[in] values_shape pointer to values shape array - * \param[in] values_shape_len length of the values_shape - * \param[in] values pointer to an array of values. For strings, pass const char**. - * \param[in] indices_data pointer to a location of COO indices - * \param[in] indices_num number of COO indices - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * This fills populates an empty tensor that was created using OrtApi::CreateSparseTensorAsOrtValue. + * This will allocate required memory and copy the supplied NNZ values and COO indices into that memory allocation. + * Memory allocation is performed using the allocator that was specified with OrtApi::CreateSparseTensorAsOrtValue. 
+ * + * \param[in,out] ort_value ::OrtValue to populate with data + * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified + * at the creation time has memory info that is not the same as mem_info argument to this function a X-device copy will be performed. + * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer. + * \param[in] values_shape pointer to values shape array + * \param[in] values_shape_len length of the values_shape + * \param[in] values pointer to an array of values. For strings, pass const char**. + * \param[in] indices_data pointer to a location of COO indices + * \param[in] indices_num number of COO indices + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(FillSparseTensorCoo, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info, _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values, _In_ const int64_t* indices_data, size_t indices_num); /** - * This fills populates an empty tensor that was created using OrtApi::CreateSparseTensorAsOrtValue. - * This will allocate required memory and copy the supplied NNZ values and CSR indices into that memory allocation. - * Memory allocation is performed using the allocator that was specified with OrtApi::CreateSparseTensorAsOrtValue. - * - * \param[in,out] ort_value ::OrtValue to populate with data - * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified - * at the creation time has memory info that is not the same as mem_info argument to this function a X-device copy will be performed. - * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer. - * \param[in] values_shape pointer to values shape array - * \param[in] values_shape_len length of the values_shape - * \param[in] values - pointer to an array of values. For strings, pass const char**. 
- * \param[in] inner_indices_data pointer to a location of CSR inner indices - * \param[in] inner_indices_num number of CSR inner indices - * \param[in] outer_indices_data pointer to a location of CSR outer indices - * \param[in] outer_indices_num number of CSR outer indices - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * This fills populates an empty tensor that was created using OrtApi::CreateSparseTensorAsOrtValue. + * This will allocate required memory and copy the supplied NNZ values and CSR indices into that memory allocation. + * Memory allocation is performed using the allocator that was specified with OrtApi::CreateSparseTensorAsOrtValue. + * + * \param[in,out] ort_value ::OrtValue to populate with data + * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified + * at the creation time has memory info that is not the same as mem_info argument to this function a X-device copy will be performed. + * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer. + * \param[in] values_shape pointer to values shape array + * \param[in] values_shape_len length of the values_shape + * \param[in] values - pointer to an array of values. For strings, pass const char**. 
+ * \param[in] inner_indices_data pointer to a location of CSR inner indices + * \param[in] inner_indices_num number of CSR inner indices + * \param[in] outer_indices_data pointer to a location of CSR outer indices + * \param[in] outer_indices_num number of CSR outer indices + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(FillSparseTensorCsr, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info, _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values, _In_ const int64_t* inner_indices_data, size_t inner_indices_num, _In_ const int64_t* outer_indices_data, size_t outer_indices_num); /** - * This fills populates an empty tensor that was created using OrtApi::CreateSparseTensorAsOrtValue. - * This will allocate required memory and copy the supplied NNZ values and BlockSparse indices into that memory allocation. - * Memory allocation is performed using the allocator that was specified with OrtApi::CreateSparseTensorAsOrtValue. - * - * \param[in,out] ort_value ::OrtValue to populate with data - * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified - * at the creation time has memory info that is not the same as mem_info argument to this function a X-device copy will be performed. - * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer. - * \param[in] values_shape - * \param[in] values_shape_len - * \param[in] values structure with values information - * \param[in] indices_shape_data pointer to a location of indices shape - * \param[in] indices_shape_len length of the block sparse indices shape - * \param[in] indices_data pointer to a location of indices data. Shape will determine the length of the indices data. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * This fills populates an empty tensor that was created using OrtApi::CreateSparseTensorAsOrtValue. 
+ * This will allocate required memory and copy the supplied NNZ values and BlockSparse indices into that memory allocation. + * Memory allocation is performed using the allocator that was specified with OrtApi::CreateSparseTensorAsOrtValue. + * + * \param[in,out] ort_value ::OrtValue to populate with data + * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified + * at the creation time has memory info that is not the same as mem_info argument to this function a X-device copy will be performed. + * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer. + * \param[in] values_shape + * \param[in] values_shape_len + * \param[in] values structure with values information + * \param[in] indices_shape_data pointer to a location of indices shape + * \param[in] indices_shape_len length of the block sparse indices shape + * \param[in] indices_data pointer to a location of indices data. Shape will determine the length of the indices data. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(FillSparseTensorBlockSparse, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info, _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values, _In_ const int64_t* indices_shape_data, size_t indices_shape_len, _In_ const int32_t* indices_data); /** - * Create an ::OrtValue with a sparse tensor. This is the first step. - * Next, use UseIndices() functions to supply sparse tensor with - * format specific indices data and set its sparse format to a specific enum value. - * This will not perform memory allocations. It will - * use supplied user buffer which should outlive the created sparse tensor. - * Use OrtApi::ReleaseValue to destroy the sparse tensor. It would not release the supplied values buffer. - * This function can not be used to map strings from the user allocated memory. 
Strings must always be copied - * and have UTF-8 encoding. Therefore, use OrtApi::CreateSparseTensorAsOrtValue above and then fill it with data - * using appropriate Make*() function. - * - * \param[in] info memory info where sparse values reside. - * \param[in,out] p_data pointer to a user allocated buffer with values. To create a full sparse tensor with no non-zero - * values, pass nullptr - * \param[in] dense_shape shape of the original dense tensor - * \param[in] dense_shape_len number of shape dimensions being passed - * \param[in] values_shape shape of the values data. To create a fully sparse tensor with no non-zero values, - * pass {0} shape. - * \param[in] values_shape_len number of values shape dimensions - * \param[in] type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx - * \param[out] out Should be freed by calling ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * Create an ::OrtValue with a sparse tensor. This is the first step. + * Next, use UseIndices() functions to supply sparse tensor with + * format specific indices data and set its sparse format to a specific enum value. + * This will not perform memory allocations. It will + * use supplied user buffer which should outlive the created sparse tensor. + * Use OrtApi::ReleaseValue to destroy the sparse tensor. It would not release the supplied values buffer. + * This function can not be used to map strings from the user allocated memory. Strings must always be copied + * and have UTF-8 encoding. Therefore, use OrtApi::CreateSparseTensorAsOrtValue above and then fill it with data + * using appropriate Make*() function. + * + * \param[in] info memory info where sparse values reside. + * \param[in,out] p_data pointer to a user allocated buffer with values. 
To create a full sparse tensor with no non-zero + * values, pass nullptr + * \param[in] dense_shape shape of the original dense tensor + * \param[in] dense_shape_len number of shape dimensions being passed + * \param[in] values_shape shape of the values data. To create a fully sparse tensor with no non-zero values, + * pass {0} shape. + * \param[in] values_shape_len number of values shape dimensions + * \param[in] type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx + * \param[out] out Should be freed by calling ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(CreateSparseTensorWithValuesAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data, _In_ const int64_t* dense_shape, size_t dense_shape_len, _In_ const int64_t* values_shape, size_t values_shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out); /** - * This assigns Coo format indices to the SparseTensor that was created by - * OrtApi::CreateSparseTensorWithValuesAsOrtValue above. It also sets OrtSparseFormat to - * ORT_SPARSE_COO. This will not allocate any additional memory for data. The life span of - * indices_data buffer should eclipse the life span of this ::OrtValue. - * - * \param[in,out] ort_value ::OrtValue instance constructed with OrtApi::CreateSparseTensorWithValuesAsOrtValue - * \param[in,out] indices_data pointer to a user pre-allocated buffer or nullptr for fully sparse tensors. - * \param[in] indices_num number of COO indices. Should either be 0 for fully sparse tensors, be equal - * to the number of nnz values specified to OrtApi::CreateSparseTensorWithValuesAsOrtValue for 1-D {nnz} indices or - * be twice as number of nnz values for a 2-D indices {nnz, 2} - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * This assigns Coo format indices to the SparseTensor that was created by + * OrtApi::CreateSparseTensorWithValuesAsOrtValue above. It also sets OrtSparseFormat to + * ORT_SPARSE_COO. 
This will not allocate any additional memory for data. The life span of + * indices_data buffer should eclipse the life span of this ::OrtValue. + * + * \param[in,out] ort_value ::OrtValue instance constructed with OrtApi::CreateSparseTensorWithValuesAsOrtValue + * \param[in,out] indices_data pointer to a user pre-allocated buffer or nullptr for fully sparse tensors. + * \param[in] indices_num number of COO indices. Should either be 0 for fully sparse tensors, be equal + * to the number of nnz values specified to OrtApi::CreateSparseTensorWithValuesAsOrtValue for 1-D {nnz} indices or + * be twice as number of nnz values for a 2-D indices {nnz, 2} + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(UseCooIndices, _Inout_ OrtValue* ort_value, _Inout_ int64_t* indices_data, size_t indices_num); /** - * The assigns CSR format indices to the SparseTensor that was created by - * OrtApi::CreateSparseTensorWithValuesAsOrtValue above. It also sets OrtSparseFormat to - * ORT_SPARSE_CSRC. This will not allocate any additional memory for data. The life spans of - * inner_data and outer_data buffers should eclipse the life span of this ::OrtValue. - * - * \param[in,out] ort_value ::OrtValue instance constructed with OrtApi::CreateSparseTensorWithValuesAsOrtValue - * \param[in,out] inner_data pointer to a user pre-allocated buffer or nullptr for fully sparse tensors. - * \param[in] inner_num number of inner CSR indices. Should either be 0 for fully sparse tensors or be equal - * to the number of nnz values specified to OrtApi::CreateSparseTensorWithValuesAsOrtValue. - * \param[in,out] outer_data pointer to user pre-allocated buffer or nullptr for fully sparse tensors. - * \param[in] outer_num number of CSR outer indices. Should either be 0 for fully sparse tensors or - * equal to rows + 1 of the dense shape. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * The assigns CSR format indices to the SparseTensor that was created by + * OrtApi::CreateSparseTensorWithValuesAsOrtValue above. It also sets OrtSparseFormat to + * ORT_SPARSE_CSRC. This will not allocate any additional memory for data. The life spans of + * inner_data and outer_data buffers should eclipse the life span of this ::OrtValue. + * + * \param[in,out] ort_value ::OrtValue instance constructed with OrtApi::CreateSparseTensorWithValuesAsOrtValue + * \param[in,out] inner_data pointer to a user pre-allocated buffer or nullptr for fully sparse tensors. + * \param[in] inner_num number of inner CSR indices. Should either be 0 for fully sparse tensors or be equal + * to the number of nnz values specified to OrtApi::CreateSparseTensorWithValuesAsOrtValue. + * \param[in,out] outer_data pointer to user pre-allocated buffer or nullptr for fully sparse tensors. + * \param[in] outer_num number of CSR outer indices. Should either be 0 for fully sparse tensors or + * equal to rows + 1 of the dense shape. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(UseCsrIndices, _Inout_ OrtValue* ort_value, _Inout_ int64_t* inner_data, size_t inner_num, _Inout_ int64_t* outer_data, size_t outer_num); /** - * The assigns BlockSparse format indices to the SparseTensor that was created by - * OrtApi::CreateSparseTensorWithValuesAsOrtValue above. It also sets OrtSparseFormat to - * ORT_SPARSE_BLOCK_SPARSE. This will not allocate any additional memory for data. The life span of - * indices_data buffer must eclipse the lifespan of this ::OrtValue. - * - * \param[in,out] ort_value OrtValue instance constructed with OrtApi::CreateSparseTensorWithValuesAsOrtValue - * \param[in] indices_shape pointer to indices shape. 
Use {0} for fully sparse tensors - * \param[in] indices_shape_len length of the indices shape - * \param[in,out] indices_data pointer to user pre-allocated buffer or nullptr for fully sparse tensors. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * The assigns BlockSparse format indices to the SparseTensor that was created by + * OrtApi::CreateSparseTensorWithValuesAsOrtValue above. It also sets OrtSparseFormat to + * ORT_SPARSE_BLOCK_SPARSE. This will not allocate any additional memory for data. The life span of + * indices_data buffer must eclipse the lifespan of this ::OrtValue. + * + * \param[in,out] ort_value OrtValue instance constructed with OrtApi::CreateSparseTensorWithValuesAsOrtValue + * \param[in] indices_shape pointer to indices shape. Use {0} for fully sparse tensors + * \param[in] indices_shape_len length of the indices shape + * \param[in,out] indices_data pointer to user pre-allocated buffer or nullptr for fully sparse tensors. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(UseBlockSparseIndices, _Inout_ OrtValue* ort_value, const int64_t* indices_shape, size_t indices_shape_len, _Inout_ int32_t* indices_data); /** \brief Returns sparse tensor format enum iff a given ort value contains an instance of sparse tensor. - * - * \param[in] ort_value ::OrtValue that contains an instance of sparse tensor - * \param[out] out pointer to out parameter - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] ort_value ::OrtValue that contains an instance of sparse tensor + * \param[out] out pointer to out parameter + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetSparseTensorFormat, _In_ const OrtValue* ort_value, _Out_ enum OrtSparseFormat* out); /** \brief Returns data type and shape of sparse tensor values (nnz) iff ::OrtValue contains a SparseTensor. 
- * - * \param[in] ort_value An ::OrtValue that contains a fully constructed sparse tensor - * \param[out] out Must be freed by OrtApi::ReleaseTensorTypeAndShapeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] ort_value An ::OrtValue that contains a fully constructed sparse tensor + * \param[out] out Must be freed by OrtApi::ReleaseTensorTypeAndShapeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetSparseTensorValuesTypeAndShape, _In_ const OrtValue* ort_value, _Outptr_ OrtTensorTypeAndShapeInfo** out); /** \brief Returns numeric data for sparse tensor values (nnz). For string values use GetStringTensor*(). - * - * \param[in] ort_value an instance of ::OrtValue containing sparse tensor - * \param[out] out returns a pointer to values data. Do not attempt to free this ptr. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] ort_value an instance of ::OrtValue containing sparse tensor + * \param[out] out returns a pointer to values data. Do not attempt to free this ptr. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetSparseTensorValues, _In_ const OrtValue* ort_value, _Outptr_ const void** out); /** \brief Returns data type, shape for the type of indices specified by indices_format. - * - * \param[in] ort_value ::OrtValue containing sparse tensor. - * \param[in] indices_format One of the indices formats. It is an error to request a format that the sparse - * tensor does not contain. - * \param[out] out an instance of ::OrtTensorTypeAndShapeInfo. Must be freed by OrtApi::ReleaseTensorTypeAndShapeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] ort_value ::OrtValue containing sparse tensor. + * \param[in] indices_format One of the indices formats. It is an error to request a format that the sparse + * tensor does not contain. + * \param[out] out an instance of ::OrtTensorTypeAndShapeInfo. 
Must be freed by OrtApi::ReleaseTensorTypeAndShapeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetSparseTensorIndicesTypeShape, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Outptr_ OrtTensorTypeAndShapeInfo** out); /** \brief Returns indices data for the type of the indices specified by indices_format - * - * \param[in] ort_value ::OrtValue containing sparse tensor. - * \param[in] indices_format One of the indices formats. It is an error to request a format that the sparse tensor does not contain. - * \param[out] num_indices Pointer to where the number of indices entries is returned - * \param[out] indices Returned pointer to the indices data. Do not free the returned pointer as it refers to internal data owned by the ::OrtValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] ort_value ::OrtValue containing sparse tensor. + * \param[in] indices_format One of the indices formats. It is an error to request a format that the sparse tensor does not contain. + * \param[out] num_indices Pointer to where the number of indices entries is returned + * \param[out] indices Returned pointer to the indices data. Do not free the returned pointer as it refers to internal data owned by the ::OrtValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(GetSparseTensorIndices, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Out_ size_t* num_indices, _Outptr_ const void** indices); + /// @} + /// \name OrtSessionOptions + /// @{ /** * \brief Sets out to 1 iff an optional type OrtValue has an element, 0 otherwise (OrtValue is None) * Use this API to find if the optional type OrtValue is None or not. * If the optional type OrtValue is not None, use the OrtValue just like any other OrtValue. 
- * For example, if you get an OrtValue that corresponds to Optional(tensor) and + * For example, if you get an OrtValue that corresponds to Optional(tensor) and * if HasValue() returns true, use it as tensor and so on. * \param[in] value Input OrtValue. @@ -3061,25 +3274,30 @@ struct OrtApi { * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_API2_STATUS(HasValue, _In_ const OrtValue* value, _Out_ int* out); + /// @} /// \name OrtKernelContext + /// Custom operator APIs. /// @{ - /** \brief Used for custom operators, gets the GPU compute stream to use to launch the custom a GPU kernel - * \see ::OrtCustomOp - * \param[context] OrtKernelContext instance - * \param[out] Returns pointer to a GPU compute stream that can be used to launch the custom GPU kernel. - * If retrieving the GPU compute stream is not relevant (GPU not enabled in the build, kernel partitioned to - * some other EP), then a nullptr is returned as the output param. - * Do not free or mutate the returned pointer as it refers to internal data owned by the underlying session. - * Only use it for custom kernel launching. - */ + + /** \brief Used for custom operators, gets the GPU compute stream to use to launch the custom a GPU kernel + * \see ::OrtCustomOp + * \param[in] context OrtKernelContext instance + * \param[out] out Returns pointer to a GPU compute stream that can be used to launch the custom GPU kernel. + * If retrieving the GPU compute stream is not relevant (GPU not enabled in the build, kernel partitioned to + * some other EP), then a nullptr is returned as the output param. + * Do not free or mutate the returned pointer as it refers to internal data owned by the underlying session. + * Only use it for custom kernel launching. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out); /// @} /// \name GetTensorMemoryInfo /// @{ /** \brief Returns a pointer to the ::OrtMemoryInfo of a Tensor - * \param[in] ort_value ::OrtValue containing tensor. + * \param[in] value ::OrtValue containing tensor. * \param[out] mem_info ::OrtMemoryInfo of the tensor. Do NOT free the returned pointer. It is valid for the lifetime of the ::OrtValue * * \snippet{doc} snippets.dox OrtStatus Return Value @@ -3090,7 +3308,7 @@ struct OrtApi { /// \name GetExecutionProviderApi /// @{ /** \brief Get a pointer to the requested version of the Execution Provider specific - * API extensions to the OrtApi + * API extensions to the OrtApi * \param[in] provider_name The name of the execution provider name. Currently only the following * values are supported: "DML". * \param[in] version Must be ::ORT_API_VERSION. @@ -3106,144 +3324,1594 @@ struct OrtApi { /// \name SessionOptions /// @{ /** \brief Set custom thread creation function - * - * \param[in] session options - * \param[in] custom thread creation function - * - * * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] options Session options + * \param[in] ort_custom_create_thread_fn Custom thread creation function + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn); - /** \brief Set creation options for custom thread - * - * \param[in] session options - * \param[in] custom thread creation options (can be nullptr) - * - * * \snippet{doc} snippets.dox OrtStatus Return Value - */ + /** \brief Set creation options for custom thread + * + * \param[in] options Session options + * \param[in] ort_custom_thread_creation_options Custom thread creation options (can be nullptr) + * + * 
\snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionOptionsSetCustomThreadCreationOptions, _Inout_ OrtSessionOptions* options, _In_ void* ort_custom_thread_creation_options); /** \brief Set custom thread join function - * - * \param[in] session options - * \param[in] custom join thread function, must not be nullptr when ort_custom_create_thread_fn is set - * - * * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[in] options Session options + * \param[in] ort_custom_join_thread_fn Custom join thread function, must not be nullptr when ort_custom_create_thread_fn is set + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn); /// @} /// \name OrtThreadingOptions /// @{ /** \brief Set custom thread creation function for global thread pools - * - * \param[inout] tp_options - * \param[in] custom thread creation function - * - * * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[inout] tp_options + * \param[in] ort_custom_create_thread_fn Custom thread creation function + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetGlobalCustomCreateThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn); /** \brief Set custom thread creation options for global thread pools - * - * \param[inout] tp_options - * \param[in] custom thread creation options (can be nullptr) - * - * * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[inout] tp_options + * \param[in] ort_custom_thread_creation_options Custom thread creation options (can be nullptr) + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetGlobalCustomThreadCreationOptions, _Inout_ OrtThreadingOptions* tp_options, _In_ void* ort_custom_thread_creation_options); /** \brief Set custom 
thread join function for global thread pools - * - * \param[inout] tp_options - * \param[in] custom thread join function, must not be nullptr when global ort_custom_create_thread_fn is set - * - * * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * + * \param[inout] tp_options + * \param[in] ort_custom_join_thread_fn Custom thread join function, must not be nullptr when global ort_custom_create_thread_fn is set + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SetGlobalCustomJoinThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn); /// @} /** \brief Synchronize bound inputs. The call may be necessary for some providers, such as cuda, - * in case the system that allocated bound memory operated on a different stream. However, the - * operation is provider specific and could be a no-op. - * - * \param[inout] binding_ptr - * - * * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * in case the system that allocated bound memory operated on a different stream. However, the + * operation is provider specific and could be a no-op. + * + * \param[inout] binding_ptr + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SynchronizeBoundInputs, _Inout_ OrtIoBinding* binding_ptr); /** \brief Synchronize bound outputs. The call may be necessary for some providers, such as cuda, - * in case the system that allocated bound memory operated on a different stream. However, the - * operation is provider specific and could be a no-op. - * - * \param[inout] binding_ptr - * - * * \snippet{doc} snippets.dox OrtStatus Return Value - */ + * in case the system that allocated bound memory operated on a different stream. However, the + * operation is provider specific and could be a no-op. 
+ * + * \param[inout] binding_ptr + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ ORT_API2_STATUS(SynchronizeBoundOutputs, _Inout_ OrtIoBinding* binding_ptr); -}; - -/* - * Steps to use a custom op: - * 1 Create an OrtCustomOpDomain with the domain name used by the custom ops - * 2 Create an OrtCustomOp structure for each op and add them to the domain - * 3 Call OrtAddCustomOpDomain to add the custom domain of ops to the session options -*/ -#define OrtCustomOpApi OrtApi -// Specifies some characteristics of inputs/outputs of custom ops: -// Specify if the inputs/outputs are one of: -// 1) Non-optional (input/output must be present in the node) -// 2) Optional (input/output may be absent in the node) -typedef enum OrtCustomOpInputOutputCharacteristic { - // TODO: Support 'Variadic' inputs/outputs - INPUT_OUTPUT_REQUIRED = 0, - INPUT_OUTPUT_OPTIONAL, -} OrtCustomOpInputOutputCharacteristic; + /// \name OrtSessionOptions + /// @{ -/* - * The OrtCustomOp structure defines a custom op's schema and its kernel callbacks. The callbacks are filled in by - * the implementor of the custom op. -*/ -struct OrtCustomOp { - uint32_t version; // Must be initialized to ORT_API_VERSION + /** \brief Append CUDA execution provider to the session options + * + * If CUDA is not available (due to a non CUDA enabled build), this function will return failure. + * + * This is slightly different from OrtApi::SessionOptionsAppendExecutionProvider_CUDA, which takes an + * ::OrtCUDAProviderOptions that is publicly defined. This function takes an opaque ::OrtCUDAProviderOptionsV2 + * which must be created with OrtApi::CreateCUDAProviderOptions. + * + * For OrtApi::SessionOptionsAppendExecutionProvider_CUDA, the user needs to instantiate ::OrtCUDAProviderOptions + * as well as allocate/release buffers for some members of ::OrtCUDAProviderOptions. + * Here, OrtApi::CreateCUDAProviderOptions and OrtApi::ReleaseCUDAProviderOptions will do the memory management for you.
+ * + * \param[in] options + * \param[in] cuda_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.11. + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CUDA_V2, + _In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptionsV2* cuda_options); - // This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below. - void*(ORT_API_CALL* CreateKernel)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api, - _In_ const OrtKernelInfo* info); + /// @} + /// \name OrtCUDAProviderOptionsV2 + /// @{ - // Returns the name of the op - const char*(ORT_API_CALL* GetName)(_In_ const struct OrtCustomOp* op); + /** \brief Create an OrtCUDAProviderOptionsV2 + * + * \param[out] out Newly created ::OrtCUDAProviderOptionsV2. Must be released with OrtApi::ReleaseCUDAProviderOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.11. + */ + ORT_API2_STATUS(CreateCUDAProviderOptions, _Outptr_ OrtCUDAProviderOptionsV2** out); - // Returns the type of the execution provider, return nullptr to use CPU execution provider - const char*(ORT_API_CALL* GetExecutionProviderType)(_In_ const struct OrtCustomOp* op); + /** \brief Set options in a CUDA Execution Provider. + * + * Please refer to https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options + * to know the available keys and values. Key should be in null terminated string format of the member of ::OrtCUDAProviderOptionsV2 + * and value should be its related range. Recreates the options and only sets the supplied values.
+ * + * For example, key="device_id" and value="0" + * + * \param[in] cuda_options + * \param[in] provider_options_keys Array of UTF-8 null-terminated strings for provider options keys + * \param[in] provider_options_values Array of UTF-8 null-terminated strings for provider options values + * \param[in] num_keys Number of elements in the `provider_options_keys` and `provider_options_values` arrays + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.11. + */ + ORT_API2_STATUS(UpdateCUDAProviderOptions, _Inout_ OrtCUDAProviderOptionsV2* cuda_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); - // Returns the count and types of the input & output tensors - ONNXTensorElementDataType(ORT_API_CALL* GetInputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); - size_t(ORT_API_CALL* GetInputTypeCount)(_In_ const struct OrtCustomOp* op); - ONNXTensorElementDataType(ORT_API_CALL* GetOutputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); - size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ const struct OrtCustomOp* op); + /** + * Get serialized CUDA provider options string. + * + * For example, "device_id=0;arena_extend_strategy=0;......" + * + * \param cuda_options - OrtCUDAProviderOptionsV2 instance + * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() + * the specified allocator will be used to allocate contiguous buffers for output strings and lengths. + * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.11.
+ */ + ORT_API2_STATUS(GetCUDAProviderOptionsAsString, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); - // Op kernel callbacks - void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtKernelContext* context); - void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel); + /** \brief Release an ::OrtCUDAProviderOptionsV2 + * + * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does + * + * \since Version 1.11. + */ + void(ORT_API_CALL* ReleaseCUDAProviderOptions)(_Frees_ptr_opt_ OrtCUDAProviderOptionsV2* input); - // Returns the characteristics of the input & output tensors - OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetInputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index); - OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetOutputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index); -}; + /// @} + + /** \brief Append MIGraphX provider to session options + * + * If MIGraphX is not available (due to a non MIGraphX enabled build, or if MIGraphX is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] migraphx_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.11. + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_MIGraphX, + _In_ OrtSessionOptions* options, _In_ const OrtMIGraphXProviderOptions* migraphx_options); + + /** \brief Replace initialized Tensors with external data with the data provided in initializers. + * + * The function will find the initialized TensorProtos with external data in the graph with the provided names and + * replace them with the provided tensors. The API verifies that the TensorProto being replaced + * has an external data reference and has the same name, dimensions and data type as its replacement. 
The replacement + * will occur before any of the optimizations take place. The data will be copied into the graph + * since TensorProto can't refer to the user provided buffers. + * + * Once the model has been loaded, the OrtValue(s) added to SessionOptions instance will be removed + * from the internal SessionOptions copy to save memory, the user provided buffers can then be deallocated + * and the SessionOptions instance that refers to them can be destroyed. + * + * \param[in] options + * \param[in] initializer_names Array of null terminated UTF-8 encoded strings of the initializers names. + * \param[in] initializers Array of ::OrtValue type + * \param[in] num_initializers Number of elements in the initializer_names and initializers + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.12. + */ + ORT_API2_STATUS(AddExternalInitializers, _In_ OrtSessionOptions* options, + _In_reads_(num_initializers) const char* const* initializer_names, + _In_reads_(num_initializers) const OrtValue* const* initializers, size_t num_initializers); + + /** \brief: Create attribute of onnxruntime operator + * + * \param[in] name Name of the attribute + * \param[in] data Data content of the attribute + * \param[in] len Number of bytes stored in data + * \param[in] type Data type + * \param[out] op_attr Attribute that has been created, which must be released by OrtApi::ReleaseOpAttr + * + * \since Version 1.12. + */ + ORT_API2_STATUS(CreateOpAttr, + _In_ const char* name, + _In_ const void* data, + _In_ int len, + _In_ OrtOpAttrType type, + _Outptr_ OrtOpAttr** op_attr); + + /* \brief: Release op attribute + * + * \param[in] opAttr Attribute created by OrtApi::CreateOpAttr + * + * \since Version 1.12. 
+ */ + ORT_CLASS_RELEASE(OpAttr); + + /** \brief: Create onnxruntime native operator + * + * \param[in] info Kernel info + * \param[in] op_name Operator name + * \param[in] domain Operator domain + * \param[in] version Operator opset version + * \param[in] type_constraint_names Name of the type constraints, such as "T" or "T1" + * \param[in] type_constraint_values Type of each constraint + * \param[in] type_constraint_count Number of constraints + * \param[in] attr_values Attributes used to initialize the operator + * \param[in] attr_count Number of the attributes + * \param[in] input_count Number of inputs + * \param[in] output_count Number of outputs + * \param[out] ort_op Operator that has been created + * + * \since Version 1.12. + */ + ORT_API2_STATUS(CreateOp, + _In_ const OrtKernelInfo* info, + _In_z_ const char* op_name, + _In_z_ const char* domain, + int version, + _In_reads_(type_constraint_count) const char** type_constraint_names, + _In_reads_(type_constraint_count) const ONNXTensorElementDataType* type_constraint_values, + int type_constraint_count, + _In_reads_(attr_count) const OrtOpAttr* const* attr_values, + int attr_count, + int input_count, + int output_count, + _Outptr_ OrtOp** ort_op); + + /** \brief: Invoke the operator created by OrtApi::CreateOp + * The inputs must follow the order as specified in onnx specification + * + * \param[in] context Kernel context + * \param[in] ort_op Operator that has been created + * \param[in] input_values Array of inputs + * \param[in] input_count Number of inputs + * \param[in] output_values Array of outputs + * \param[in] output_count Number of outputs + * + * \since Version 1.12.
+ */ + ORT_API2_STATUS(InvokeOp, + _In_ const OrtKernelContext* context, + _In_ const OrtOp* ort_op, + _In_ const OrtValue* const* input_values, + _In_ int input_count, + _Inout_ OrtValue* const* output_values, + _In_ int output_count); + + /* \brief: Release an onnxruntime operator + * + * \param[in] Op Operator created by OrtApi::CreateOp + * + * \since Version 1.12. + */ + ORT_CLASS_RELEASE(Op); + + /** \brief: Append execution provider to the session options. + * \param[in] options + * \param[in] provider_name - provider to add. + * \param[in] provider_options_keys - keys to configure the provider options + * \param[in] provider_options_values - values to configure the provider options + * \param[in] num_keys - number of keys passed in + * + * Currently supported providers: + * QNN + * SNPE + * XNNPACK + * + * Note: If an execution provider has a dedicated SessionOptionsAppendExecutionProvider_ function + * that should be used to add it. + * + * QNN supported keys: + * "backend_path": file path to QNN backend library. + * "profiling_level": QNN profiling level, options: "off", "basic", "detailed". Default to off. + * "profiling_file_path": QNN profiling file path if ETW not enabled. + * "rpc_control_latency": QNN RPC control latency. + * "vtcm_mb": QNN VTCM size in MB. default to 0(not set). + * "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance", + * "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default". + * "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will + * dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and + * may alter model/EP partitioning. Use only for debugging. + * "qnn_context_priority": QNN context priority, options: "low", "normal", "normal_high", "high". Default to "normal". 
+ * "htp_graph_finalization_optimization_mode": Set the optimization mode for graph finalization on the HTP backend. Available options: + * - "0": Default. + * - "1": Faster preparation time, less optimal graph. + * - "2": Longer preparation time, more optimal graph. + * - "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific details. + * "soc_model": The SoC model number. Refer to the QNN SDK documentation for valid values. Defaults to "0" (unknown). + * "htp_arch": The minimum HTP architecture the driver will use to select compatible QNN operators. Available options: + * - "0": Default (none). + * - "68" + * - "69" + * - "73" + * - "75" + * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). + * "enable_htp_fp16_precision": Used for float32 model for HTP backend. + * Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision. + * - "0": With fp32 precision. + * - "1": Default. With fp16 precision. + * "enable_htp_weight_sharing": Enable QNN weight sharing feature while compiling multiple graphs into one QNN context. + * - "0": Default. Disabled. + * - "1": Enabled. + * "offload_graph_io_quantization": Offload graph input quantization and graph output dequantization to another + * execution provider (typically CPU EP). + * - "0": Default. Disabled. QNN EP will handle quantization and dequantization of graph I/O. + * - "1": Enabled. + * + * SNPE supported keys: + * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", + * "DSP", "DSP_FIXED8_TF", "AIP_FIXED_TF", "AIP_FIXED8_TF". + * Mapping to SNPE Runtime_t definition: CPU, CPU_FLOAT32 => zdl::DlSystem::Runtime_t::CPU; + * GPU, GPU_FLOAT32_16_HYBRID => zdl::DlSystem::Runtime_t::GPU; + * GPU_FLOAT16 => zdl::DlSystem::Runtime_t::GPU_FLOAT16; + * DSP, DSP_FIXED8_TF => zdl::DlSystem::Runtime_t::DSP. 
+ * AIP_FIXED_TF, AIP_FIXED8_TF => zdl::DlSystem::Runtime_t::AIP_FIXED_TF. + * "priority": execution priority, options: "low", "normal". + * "buffer_type": ITensor or user buffers, options: "ITENSOR", user buffer with different types - "TF8", "TF16", "UINT8", "FLOAT". + * "ITENSOR" -- default, ITensor which is float only. + * "TF8" -- quantized model required, "FLOAT" -- for both quantized or non-quantized model + * "enable_init_cache": enable SNPE init caching feature, set to 1 to enabled it. Disabled by default. + * If SNPE is not available (due to a non Snpe enabled build or its dependencies not being installed), this function will fail. + * + * XNNPACK supported keys: + * "intra_op_num_threads": number of thread-pool size to use for XNNPACK execution provider. + * default value is 0, which means to use the session thread-pool size. + * + * \since Version 1.12. + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider, _In_ OrtSessionOptions* options, + _In_ const char* provider_name, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /* \brief: Get a copy of kernel info + * + * \param[in] info Kernel info + * \param[out] info_copy Copy of kernel info + * + * \since Version 1.12. + */ + ORT_API2_STATUS(CopyKernelInfo, + _In_ const OrtKernelInfo* info, + _Outptr_ OrtKernelInfo** info_copy); + + /* \brief: Release kernel info + * + * \param[in] KernelInfo A copy of kernel info returned by CopyKernelInfo + * + * \since Version 1.12. + */ + ORT_CLASS_RELEASE(KernelInfo); + + /// \name Ort Training + /// @{ + /** \brief Gets the Training C Api struct + * + * Call this function to access the ::OrtTrainingApi structure that holds pointers to functions that enable + * training with onnxruntime. + * \note A NULL pointer will be returned and no error message will be printed if the training api + * is not supported with this build. 
A NULL pointer will be returned and an error message will be + * printed if the provided version is unsupported, for example when using a runtime older than the + * version created with this header file. + * + * \param[in] version Must be ::ORT_API_VERSION + * \return The ::OrtTrainingApi struct for the version requested. + * + * \since Version 1.13 + */ + const OrtTrainingApi*(ORT_API_CALL* GetTrainingApi)(uint32_t version)NO_EXCEPTION; + + /// @} + + /** \brief Append CANN provider to session options + * + * If CANN is not available (due to a non CANN enabled build, or if CANN is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] cann_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.13. + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CANN, + _In_ OrtSessionOptions* options, _In_ const OrtCANNProviderOptions* cann_options); + + /** \brief Create an OrtCANNProviderOptions + * + * \param[out] out created ::OrtCANNProviderOptions. Must be released with OrtApi::ReleaseCANNProviderOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.13. + */ + ORT_API2_STATUS(CreateCANNProviderOptions, _Outptr_ OrtCANNProviderOptions** out); + + /** \brief Set options in a CANN Execution Provider. + * + * \param[in] cann_options + * \param[in] provider_options_keys Array of UTF-8 null-terminated strings for provider options keys + * \param[in] provider_options_values Array of UTF-8 null-terminated strings for provider options values + * \param[in] num_keys Number of elements in the `provider_options_keys` and `provider_options_values` arrays + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.13.
+ */ + ORT_API2_STATUS(UpdateCANNProviderOptions, _Inout_ OrtCANNProviderOptions* cann_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** \brief Get serialized CANN provider options string. + * + * \param[in] cann_options OrtCANNProviderOptions instance + * \param[in] allocator a ptr to an instance of OrtAllocator obtained with CreateAllocator() + * or GetAllocatorWithDefaultOptions(), the specified allocator will be used to allocate + * continuous buffers for output strings and lengths. + * \param[out] ptr is a UTF-8 null terminated string allocated using 'allocator'. + * The caller is responsible for using the same allocator to free it. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.13. + */ + ORT_API2_STATUS(GetCANNProviderOptionsAsString, _In_ const OrtCANNProviderOptions* cann_options, + _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); + + /** \brief Release an OrtCANNProviderOptions + * + * \param[in] input The pointer of OrtCANNProviderOptions which will be deleted + * + * \since Version 1.13. + */ + void(ORT_API_CALL* ReleaseCANNProviderOptions)(_Frees_ptr_opt_ OrtCANNProviderOptions* input); + + /* \brief Get OrtDevice type from MemoryInfo + * + * \since Version 1.14 + */ + void(ORT_API_CALL* MemoryInfoGetDeviceType)(_In_ const OrtMemoryInfo* ptr, _Out_ OrtMemoryInfoDeviceType* out); + + /* \brief Update the OrtEnv instance with custom log severity level + * + * \param[in] ort_env The OrtEnv instance being used + * \param[in] log_severity_level The log severity level. + * + * \since Version 1.14. 
+ */ + ORT_API2_STATUS(UpdateEnvWithCustomLogLevel, _In_ OrtEnv* ort_env, OrtLoggingLevel log_severity_level); + + /* \brief Set affinities for intra op threads + * + * Affinity string follows format: + * logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id + * Semicolons separate the configurations of different threads, while commas separate the processors that the i-th thread is expected to attach to. + * e.g. 1,2,3;4,5 + * specifies affinities for two threads, with the 1st thread attached to the 1st, 2nd, and 3rd processors, and the 2nd thread to the 4th and 5th. + * To ease the configuration, an "interval" is also allowed: + * e.g. 1-8;9-16;17-24 + * specifies that the 1st thread runs on the first eight processors, the 2nd thread runs on the next eight processors, and so forth. + * Note: + * 1. Once set, the number of thread affinities must be equal to intra_op_num_threads - 1, + * ort does not set affinity on the main thread which is started and managed by the calling app; + * 2. For windows, ort will infer the group id from a logical processor id, for example, assuming there are two groups, each with 64 logical processors, + * an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group. + * Hence 64-65 is an invalid configuration, because a windows thread cannot be attached to processors across a group boundary. + * + * \since Version 1.14 + */ + ORT_API2_STATUS(SetGlobalIntraOpThreadAffinity, _Inout_ OrtThreadingOptions* tp_options, const char* affinity_string); + + /** \brief Register custom ops from a shared library. + * + * Loads a shared library (.dll on windows, .so on linux, etc) named 'library_name' and looks for this entry point: + * OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api); + * It then passes in the provided session options to this function along with the api base. 
+ * + * The handle to the loaded library is automatically released by ORT when the last OrtSession that references the + * library handle is released. If no OrtSession is created, then the library handle is released when the provided + * OrtSessionOptions is released. + * + * \param[in] options The session options. + * \param[in] library_name The name of the shared library to load and register. Refer to OS-specific dynamic library + * loading utilities (e.g., LoadLibraryEx on Windows or dlopen on Linux/MacOS) for information + * on the format of library names and search paths. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(RegisterCustomOpsLibrary_V2, _Inout_ OrtSessionOptions* options, _In_ const ORTCHAR_T* library_name); + + /** \brief Register custom ops by calling a RegisterCustomOpsFn function. + * + * Searches for registration_func_name and if found calls it. + * + * The library containing the function must either be linked against or previously loaded by the executable. + * + * If you want ONNX Runtime to load the library and manage its lifetime, use RegisterCustomOpsLibrary_V2. + * + * RegisterCustomOpsUsingFunction can be used in scenarios where it may not be possible for ONNX Runtime to load + * the library from a path. e.g. mobile platforms where the library must be linked into the app. + * + * The registration function must have the signature of RegisterCustomOpsFn: + * OrtStatus* (*fn)(OrtSessionOptions* options, const OrtApiBase* api); + * + * See https://onnxruntime.ai/docs/reference/operators/add-custom-op.html for details on how the registration + * function should be implemented. + * + * \param[in] options OrtSessionOptions that is passed through as the first argument in the call to the + * registration function. + * \param[in] registration_func_name Name of registration function to use. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(RegisterCustomOpsUsingFunction, _Inout_ OrtSessionOptions* options, + _In_ const char* registration_func_name); + + /// \name OrtKernelInfo + /// Custom operator APIs. + /// @{ + + /** \brief Get the number of inputs from ::OrtKernelInfo. + * + * Used in the CreateKernel callback of an OrtCustomOp to query the number of inputs + * during kernel/session creation. + * + * \param[in] info Instance of ::OrtKernelInfo. + * \param[out] out Pointer to variable assigned with the result on success. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(KernelInfo_GetInputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out); + + /** \brief Get the number of outputs from ::OrtKernelInfo. + * + * Used in the CreateKernel callback of an OrtCustomOp to query the number of outputs + * during kernel/session creation. + * + * \param[in] info Instance of ::OrtKernelInfo. + * \param[out] out Pointer to variable assigned with the result on success. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(KernelInfo_GetOutputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out); + + /** \brief Get the name of a ::OrtKernelInfo's input. + * + * Used in the CreateKernel callback of an OrtCustomOp to query an input's name + * during kernel/session creation. + * + * If `out` is nullptr, the value of `size` is set to the size of the name + * string (including null-terminator), and a success status is returned. + * + * If the `size` parameter is greater than or equal to the name string's size, + * the value of `size` is set to the true size of the string (including null-terminator), + * the provided memory is filled with the string's contents, and a success status is returned. 
+ * + * If the `size` parameter is less than the actual string's size and `out` + * is not nullptr, the value of `size` is set to the true size of the string + * and a failure status is returned. + * + * \param[in] info An instance of ::OrtKernelInfo. + * \param[in] index The index of the input name to get. Returns a failure status if out-of-bounds. + * \param[out] out Memory location into which to write the UTF-8 null-terminated string representing the input's name. + * \param[in,out] size Pointer to the size of the `out` buffer. See above comments for details. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(KernelInfo_GetInputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out, + _Inout_ size_t* size); + + /** \brief Get the name of a ::OrtKernelInfo's output. + * + * Used in the CreateKernel callback of an OrtCustomOp to query an output's name + * during kernel/session creation. + * + * If `out` is nullptr, the value of `size` is set to the size of the name + * string (including null-terminator), and a success status is returned. + * + * If the `size` parameter is greater than or equal to the name string's size, + * the value of `size` is set to the true size of the string (including null-terminator), + * the provided memory is filled with the string's contents, and a success status is returned. + * + * If the `size` parameter is less than the actual string's size and `out` + * is not nullptr, the value of `size` is set to the true size of the string + * and a failure status is returned. + * + * \param[in] info An instance of ::OrtKernelInfo. + * \param[in] index The index of the output name to get. Returns a failure status if out-of-bounds. + * \param[out] out Memory location into which to write the UTF-8 null-terminated string representing the output's + * name. + * \param[in,out] size Pointer to the size of the `out` buffer. See above comments for details. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(KernelInfo_GetOutputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out, + _Inout_ size_t* size); + + /** \brief Get the type information for a ::OrtKernelInfo's input. + * + * Used in the CreateKernel callback of an OrtCustomOp to query the shape and type information + * of an input during kernel/session creation. + * + * \param[in] info An instance of ::OrtKernelInfo. + * \param[in] index Which input to get the type information for + * \param[out] type_info Pointer set to the resulting ::OrtTypeInfo. Must be freed with OrtApi::ReleaseTypeInfo. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(KernelInfo_GetInputTypeInfo, _In_ const OrtKernelInfo* info, size_t index, + _Outptr_ OrtTypeInfo** type_info); + + /** \brief Get the type information for a ::OrtKernelInfo's output. + * + * Used in the CreateKernel callback of an OrtCustomOp to query the shape and type information + * of an output during kernel/session creation. + * + * \param[in] info An instance of ::OrtKernelInfo. + * \param[in] index Which output to get the type information for + * \param[out] type_info Pointer set to the resulting ::OrtTypeInfo. Must be freed with OrtApi::ReleaseTypeInfo. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(KernelInfo_GetOutputTypeInfo, _In_ const OrtKernelInfo* info, size_t index, + _Outptr_ OrtTypeInfo** type_info); + + /** \brief Get a ::OrtValue tensor stored as an attribute in the graph node. + * + * Used in the CreateKernel callback of an OrtCustomOp to get a tensor attribute. + * + * \param[in] info ::OrtKernelInfo instance. + * \param[in] name UTF-8 null-terminated string representing the attribute's name. + * \param[in] allocator Allocator used to allocate the internal tensor state. + * \param[out] out Returns newly created ::OrtValue. 
Must be freed with OrtApi::ReleaseValue, + * which will also free internal tensor state allocated with the provided allocator. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelInfoGetAttribute_tensor, _In_ const OrtKernelInfo* info, _In_z_ const char* name, + _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out); + + /// @} + /// \name OrtSessionOptions + /// Custom operator APIs + /// @{ + + /** \brief Checks if the given session configuration entry exists. + * + * The config_key formats are defined in onnxruntime_session_options_config_keys.h + * + * Can be used in a custom operator library to check for session configuration entries + * that target one or more custom operators in the library. Example: The config entry + * custom_op.myop.some_key targets a custom op named "myop". + * + * \param[in] options The ::OrtSessionOptions instance. + * \param[in] config_key A null-terminated UTF-8 string representation of the configuration key. + * \param[out] out Pointer set to 1 if the entry exists and 0 otherwise. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(HasSessionConfigEntry, _In_ const OrtSessionOptions* options, + _In_z_ const char* config_key, _Out_ int* out); + + /** \brief Get a session configuration value. + * + * Returns a failure status if the configuration key does not exist. + * The config_key and the format of config_value are defined in onnxruntime_session_options_config_keys.h + * + * If `config_value` is nullptr, the value of `size` is set to the true size of the string + * value (including null-terminator), and a success status is returned. + * + * If the `size` parameter is greater than or equal to the actual string value's size, + * the value of `size` is set to the true size of the string value, the provided memory + * is filled with the value's contents, and a success status is returned. 
+ * + * If the `size` parameter is less than the actual string value's size and `config_value` + * is not nullptr, the value of `size` is set to the true size of the string value + * and a failure status is returned. + * + * Can be used in a custom operator library to get session configuration entries + * that target one or more custom operators in the library. Example: The config entry + * custom_op.myop.some_key targets a custom op named "myop". + * + * \param[in] options The session options. + * \param[in] config_key A null-terminated UTF-8 string representation of the config key. + * \param[out] config_value Pointer to memory where the null-terminated UTF-8 string value will be stored. + * \param[in,out] size Pointer to the size of the `config_value` buffer. See above comments for details. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(GetSessionConfigEntry, _In_ const OrtSessionOptions* options, + _In_z_ const char* config_key, _Out_ char* config_value, _Inout_ size_t* size); + + /// @} + + /** \brief Append dnnl provider to session options + * + * If oneDNN is not available, this function will return failure. + * + * \param[in] options + * \param[in] dnnl_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.15. + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_Dnnl, + _In_ OrtSessionOptions* options, _In_ const OrtDnnlProviderOptions* dnnl_options); + + /** \brief Create an OrtDnnlProviderOptions + * + * \param[out] out Newly created ::OrtDnnlProviderOptions. Must be released with OrtApi::ReleaseDnnlProviderOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.15. + */ + ORT_API2_STATUS(CreateDnnlProviderOptions, _Outptr_ OrtDnnlProviderOptions** out); + + /** \brief Set options in a oneDNN Execution Provider. 
+ * + * Key should be in null terminated string format of the member of ::OrtDnnlProviderOptions + * and value should be its related range. + * + * For example, key="use_arena" and value="1" + * + * \param[in] dnnl_options + * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys + * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values + * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.15. + */ + ORT_API2_STATUS(UpdateDnnlProviderOptions, _Inout_ OrtDnnlProviderOptions* dnnl_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** + * Get serialized oneDNN provider options string. + * + * For example, "use_arena=1;......" + * + * \param dnnl_options - OrtDnnlProviderOptions instance + * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() + * the specified allocator will be used to allocate continuous buffers for output strings and lengths. + * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.15. + */ + ORT_API2_STATUS(GetDnnlProviderOptionsAsString, _In_ const OrtDnnlProviderOptions* dnnl_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); + + /** \brief Release an ::OrtDnnlProviderOptions + * + * \since Version 1.15. + */ + void(ORT_API_CALL* ReleaseDnnlProviderOptions)(_Frees_ptr_opt_ OrtDnnlProviderOptions* input); + + /// \name OrtKernelInfo + /// Custom operator APIs. + /// @{ + + /** \brief Get the graph node name from ::OrtKernelInfo. 
+ * + * If `out` is nullptr, the value of `size` is set to the size of the name + * string (including null-terminator), and a success status is returned. + * + * If the `size` parameter is greater than or equal to the name string's size, + * the value of `size` is set to the true size of the string (including null-terminator), + * the provided memory is filled with the string's contents, and a success status is returned. + * + * If the `size` parameter is less than the actual string's size and `out` + * is not nullptr, the value of `size` is set to the true size of the string + * and a failure status is returned. + * + * Can be used in a custom operator's CreateKernel callback to get the name of the operator's node name in the graph. + * + * \param[in] info An instance of ::OrtKernelInfo. + * \param[out] out Memory location into which to write the UTF-8 null-terminated string representing the name. + * \param[in,out] size Pointer to the size of the `out` buffer. See above comments for details. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.15 + */ + ORT_API2_STATUS(KernelInfo_GetNodeName, _In_ const OrtKernelInfo* info, _Out_ char* out, _Inout_ size_t* size); + + /** \brief Get the session logger from ::OrtKernelInfo. + * + * Used in the CreateKernel callback of an OrtCustomOp to get a logger that can be used to log + * messages. + * + * \param[in] info An instance of ::OrtKernelInfo. + * \param[out] logger Pointer set to the session's ::OrtLogger. Owned by ONNX Runtime, so do not free. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.15 + */ + ORT_API2_STATUS(KernelInfo_GetLogger, _In_ const OrtKernelInfo* info, _Outptr_ const OrtLogger** logger); + + /// @} + /// \name OrtKernelContext + /// Custom operator APIs. + /// @{ + + /** \brief Get the runtime logger from ::OrtKernelContext. 
+ * + * Used in the KernelCompute callback of an OrtCustomOp to get a logger that can be used to log + * messages during inference. + * + * \param[in] context An instance of ::OrtKernelContext. + * \param[out] logger Pointer set to the kernel context's ::OrtLogger. Owned by ONNX Runtime, so do not free. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.15 + */ + ORT_API2_STATUS(KernelContext_GetLogger, _In_ const OrtKernelContext* context, _Outptr_ const OrtLogger** logger); + + /// @} + /// \name OrtLogger + /// Custom operator APIs. + /// @{ + + /** \brief Logs a message at the given severity level using the provided ::OrtLogger. + * + * Only messages with a severity level equal or greater than the ::OrtLogger's logging severity level + * are logged. Use OrtApi::Logger_GetLoggingSeverityLevel to get the ::OrtLogger's logging severity + * level. + * + * Can be used in custom operators to log messages with the logger retrieved via OrtApi::KernelInfo_GetLogger. + * + * \param[in] logger The ::OrtLogger instance. + * \param[in] log_severity_level The message's severity level. + * \param[in] message The message to log. + * \param[in] file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE. + * \param[in] line_number The file line number in which the message is logged. Usually the value of __LINE__. + * \param[in] func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.15 + */ + ORT_API2_STATUS(Logger_LogMessage, _In_ const OrtLogger* logger, OrtLoggingLevel log_severity_level, + _In_z_ const char* message, _In_z_ const ORTCHAR_T* file_path, int line_number, + _In_z_ const char* func_name); + + /** \brief Get the logging severity level of the ::OrtLogger. 
+ * + * Can be used in a custom operator to get the logging severity level of the ::OrtLogger associated with + * the ::OrtKernelInfo. + * + * \param[in] logger The ::OrtLogger instance. + * \param[out] out Pointer to variable assigned with the logging severity level on success. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.15 + */ + ORT_API2_STATUS(Logger_GetLoggingSeverityLevel, _In_ const OrtLogger* logger, _Out_ OrtLoggingLevel* out); + + /// @} + + /** \brief Get a ::OrtValue tensor stored as a constant initializer in the graph node. + * + * Used in the CreateKernel callback of an OrtCustomOp to get a tensor value. + * + * \param[in] info ::OrtKernelInfo instance. + * \param[in] index The node index. + * \param[out] is_constant Is it a constant node input or not. + * \param[out] out The OrtValue tensor value. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.15. + */ + ORT_API2_STATUS(KernelInfoGetConstantInput_tensor, _In_ const OrtKernelInfo* info, size_t index, _Out_ int* is_constant, _Outptr_ const OrtValue** out); + + /** \brief Get Optional Type information from an ::OrtTypeInfo + * + * This augments ::OrtTypeInfo to return an ::OrtOptionalTypeInfo when the type is optional. + * The OrtOptionalTypeInfo also has a nested ::OrtTypeInfo that describes the type of the optional value. + * ::OrtOptionalTypeInfo type can only appear within model metadata to describe inputs/outputs. + * The actual OrtValues that are supplied in place of optional type inputs should contain + * specific type that is described by ::OrtOptionalTypeInfo. + * + * So the picture: ::OrtTypeInfo -> ::OrtOptionalTypeInfo -> ::OrtTypeInfo (describes the type that can be supplied + * in place of the optional type when creating the actual ::OrtValue). + * + * \param[in] type_info + * \param[out] out A pointer to the ::OrtOptionalTypeInfo. Do not free this value, + * it is owned by OrtTypeInfo instance. 
When the type_info does not represent + * optional type, nullptr is returned in out. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.15. + */ + ORT_API2_STATUS(CastTypeInfoToOptionalTypeInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtOptionalTypeInfo** out); + + /** \brief Get OrtTypeInfo for the allowed contained type from an ::OrtOptionalTypeInfo. + * + * This augments ::OrtOptionalTypeInfo to return an ::OrtTypeInfo for the contained type. + * The OrtOptionalTypeInfo has a nested ::OrtTypeInfo that describes the type of the optional value. + * ::OrtOptionalTypeInfo type can only appear within model metadata to describe inputs/outputs. + * The actual OrtValues that are supplied in place of optional type inputs should contain + * specific type that is described by the returned ::OrtTypeInfo. + * + * \param[in] optional_type_info + * \param[out] out A pointer to the ::OrtTypeInfo for what the optional value could be. + * It is owned by the OrtOptionalTypeInfo instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.15. + */ + ORT_API2_STATUS(GetOptionalContainedTypeInfo, _In_ const OrtOptionalTypeInfo* optional_type_info, + _Outptr_ OrtTypeInfo** out); + + /** \brief Set a single string in a string tensor + * Do not zero terminate the string data. + * + * \param[in] value A string tensor + * \param[in] index - flat index of the element + * \param[in] length_in_bytes length of the buffer in utf-8 bytes (without the null terminator) + * \param[in,out] buffer - address of return value + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetResizedStringTensorElementBuffer, _Inout_ OrtValue* value, _In_ size_t index, _In_ size_t length_in_bytes, _Inout_ char** buffer); + + /** \brief Get Allocator from KernelContext for a specific memoryInfo. 
Please use C API ReleaseAllocator to release out object + * + * \param[in] context OrtKernelContext instance + * \param[in] mem_info OrtMemoryInfo instance + * \param[out] out A pointer to OrtAllocator. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.15. + */ + ORT_API2_STATUS(KernelContext_GetAllocator, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _Outptr_ OrtAllocator** out); + + /** \brief Returns a null terminated string of the build info including git info and cxx flags + * + * \return UTF-8 encoded version string. Do not deallocate the returned buffer. + * + * \since Version 1.15. + */ + const char*(ORT_API_CALL* GetBuildInfoString)(void); + + /// \name OrtROCMProviderOptions + /// @{ + + /** \brief Create an OrtROCMProviderOptions + * + * \param[out] out Newly created ::OrtROCMProviderOptions. Must be released with OrtApi::ReleaseROCMProviderOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.16. + */ + ORT_API2_STATUS(CreateROCMProviderOptions, _Outptr_ OrtROCMProviderOptions** out); + + /** \brief Set options in a ROCm Execution Provider. + * + * Please refer to https://onnxruntime.ai/docs/execution-providers/ROCm-ExecutionProvider.html + * to know the available keys and values. Key should be in null terminated string format of the member of + * ::OrtROCMProviderOptions and value should be its related range. + * + * For example, key="device_id" and value="0" + * + * \param[in] rocm_options + * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys + * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values + * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.16. 
+ */ + ORT_API2_STATUS(UpdateROCMProviderOptions, _Inout_ OrtROCMProviderOptions* rocm_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** + * Get serialized ROCm provider options string. + * + * For example, "device_id=0;arena_extend_strategy=0;......" + * + * \param rocm_options - OrtROCMProviderOptions instance + * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() + * the specified allocator will be used to allocate continuous buffers for output strings and lengths. + * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.16. + */ + ORT_API2_STATUS(GetROCMProviderOptionsAsString, _In_ const OrtROCMProviderOptions* rocm_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); + + /** \brief Release an ::OrtROCMProviderOptions + * + * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does + * + * \since Version 1.16. + */ + void(ORT_API_CALL* ReleaseROCMProviderOptions)(_Frees_ptr_opt_ OrtROCMProviderOptions* input); + + /** \brief Create an allocator with specific type and register it with the ::OrtEnv + * This API enhance CreateAndRegisterAllocator that it can create an allocator with specific type, not just CPU allocator + * Enables sharing the allocator between multiple sessions that use the same env instance. + * Lifetime of the created allocator will be valid for the duration of the environment. + * Returns an error if an allocator with the same ::OrtMemoryInfo is already registered. 
+ * \param[in] env OrtEnv instance + * \param[in] provider_type ExecutionProvider type + * \param[in] mem_info OrtMemoryInfo instance + * \param[in] arena_cfg Arena configuration + * \param[in] provider_options_keys key of the provider options map + * \param[in] provider_options_values value of the provider options map + * \param[in] num_keys Length of the provider options map + */ + ORT_API2_STATUS(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg, + _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + + /** \brief Run the model asynchronously in a thread owned by intra op thread pool + * + * \param[in] session + * \param[in] run_options If nullptr, will use a default ::OrtRunOptions + * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names + * \param[in] input Array of ::OrtValue%s of the input values + * \param[in] input_len Number of elements in the input_names and inputs arrays + * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names + * \param[in] output_names_len Number of elements in the output_names and outputs array + * \param[out] output OrtValue* array of size output_names_len. + * On calling RunAsync, output[i] could either be a null or a pointer to a preallocated OrtValue. + * Later, the output array will be passed to run_async_callback with all null(s) filled with valid + * OrtValue pointer(s) allocated by onnxruntime. + * NOTE: it is customer's duty to finally release the output array and each of its member, + * regardless of whether the member (OrtValue*) is allocated by onnxruntime or preallocated by the customer. 
+ * \param[in] run_async_callback Callback function on model run completion + * \param[in] user_data User data that pass back to run_async_callback + */ + ORT_API2_STATUS(RunAsync, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options, + _In_reads_(input_len) const char* const* input_names, + _In_reads_(input_len) const OrtValue* const* input, size_t input_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, + _Inout_updates_all_(output_names_len) OrtValue** output, + _In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data); + + /** + * Update TensorRT EP provider option where its data type is pointer, for example 'user_compute_stream'. + * If the data type of the provider option can be represented by string please use UpdateTensorRTProviderOptions. + * + * Note: It's caller's responsibility to properly manage the lifetime of the instance pointed by this pointer. + * + * \param tensorrt_options - OrtTensorRTProviderOptionsV2 instance + * \param key - Name of the provider option + * \param value - A pointer to the instance that will be assigned to this provider option + * + * \since Version 1.16. + */ + ORT_API2_STATUS(UpdateTensorRTProviderOptionsWithValue, _Inout_ OrtTensorRTProviderOptionsV2* tensorrt_options, _In_ const char* key, _In_ void* value); + + /** + * Get TensorRT EP provider option where its data type is pointer. + * If the data type of the provider option can be represented by string please use GetTensorRTProviderOptionsAsString. + * + * \param tensorrt_options - OrtTensorRTProviderOptionsV2 instance + * \param key - Name of the provider option + * \param ptr - A pointer to the instance that is kept by the provider option + * + * \since Version 1.16. 
+ */ + ORT_API2_STATUS(GetTensorRTProviderOptionsByName, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options, _In_ const char* key, _Outptr_ void** ptr); + + /** + * Update CUDA EP provider option where its data type is pointer, for example 'user_compute_stream'. + * If the data type of the provider option can be represented by string please use UpdateCUDAProviderOptions. + * + * Note: It's caller's responsibility to properly manage the lifetime of the instance pointed by this pointer. + * + * \param cuda_options - OrtCUDAProviderOptionsV2 instance + * \param key - Name of the provider option + * \param value - A pointer to the instance that will be assigned to this provider option + * + * \since Version 1.16. + */ + ORT_API2_STATUS(UpdateCUDAProviderOptionsWithValue, _Inout_ OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _In_ void* value); + + /** + * Get CUDA EP provider option where its data type is pointer. + * If the data type of the provider option can be represented by string please use GetCUDAProviderOptionsAsString. + * + * \param cuda_options - OrtCUDAProviderOptionsV2 instance + * \param key - Name of the provider option + * \param ptr - A pointer to the instance that is kept by the provider option + * + * \since Version 1.16. + */ + ORT_API2_STATUS(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr); + + /** + * Get a EP resource. + * E.g. a cuda stream or a cublas handle + * + * \param context - Kernel context + * \param resource_version - Version of the resource + * \param resource_id - Type of resource + * \param resource - A pointer to returned resource + * + * \since Version 1.16. 
+ */ + ORT_API2_STATUS(KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resource_version, + _In_ int resource_id, _Outptr_ void** resource); + + /** \brief Set user logging function + * + * By default the logger created by the CreateEnv* functions is used to create the session logger as well. + * This function allows a user to override this default session logger with a logger of their own choosing. This way + * the user doesn't have to create a separate environment with a custom logger. This addresses the problem when + * the user already created an env but now wants to use a different logger for a specific session (for debugging or + * other reasons). + * + * \param[in] options + * \param[in] user_logging_function A pointer to a logging function. + * \param[in] user_logging_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to + * `user_logging_function`. This parameter is optional. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.17. + */ + ORT_API2_STATUS(SetUserLoggingFunction, _Inout_ OrtSessionOptions* options, + _In_ OrtLoggingFunction user_logging_function, _In_opt_ void* user_logging_param); + + /** + * Get number of input from OrtShapeInferContext + * + * \param[in] context + * \param[out] out The number of inputs + * + * \since Version 1.17. + */ + ORT_API2_STATUS(ShapeInferContext_GetInputCount, _In_ const OrtShapeInferContext* context, _Out_ size_t* out); + + /** + * Get type and shape info of an input + * + * \param[in] context + * \param[in] index The index of the input + * \param[out] info Type shape info of the input + * + * \since Version 1.17. + */ + ORT_API2_STATUS(ShapeInferContext_GetInputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _Outptr_ OrtTensorTypeAndShapeInfo** info); + + /** + * Get attribute from OrtShapeInferContext. 
Note that OrtShapeInferContext is a per-node context, one could only read attribute from current node. + * + * \param[in] context + * \param[in] attr_name Name of the attribute + * \param[out] attr Handle of the attribute fetched + * + * \since Version 1.17. + */ + ORT_API2_STATUS(ShapeInferContext_GetAttribute, _In_ const OrtShapeInferContext* context, _In_ const char* attr_name, _Outptr_ const OrtOpAttr** attr); + + /** + * Set type and shape info of an output + * + * \param[in] context + * \param[in] index The index of the output + * \param[out] info Type shape info of the output + * + * \since Version 1.17. + */ + ORT_API2_STATUS(ShapeInferContext_SetOutputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _In_ const OrtTensorTypeAndShapeInfo* info); + + /** + * Set symbolic shape to type shape info + * + * \param[in] info Type shape info + * \param[in] dim_params Symbolic strings + * \param[in] dim_params_length Number of strings + * + * \since Version 1.17. + */ + ORT_API2_STATUS(SetSymbolicDimensions, _In_ OrtTensorTypeAndShapeInfo* info, _In_ const char* dim_params[], _In_ size_t dim_params_length); + + /** + * Read contents of an attribute to data + * + * \param[in] op_attr + * \param[in] type Attribute type + * \param[out] data Memory address to save raw content of the attribute + * \param[in] len Number of bytes allowed to store in data + * \param[out] out Number of bytes required to save the data when the call failed, or the real number of bytes saved to data on success + * + * \since Version 1.17. + */ + ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out); + + /** \brief Set whether to use deterministic compute. + * + * Default is false. If set to true, this will enable deterministic compute for GPU kernels where possible. + * Note that this most likely will have a performance cost. 
+ * + * \param[in] options + * \param[in] value + * + * \since Version 1.17. + */ + ORT_API2_STATUS(SetDeterministicCompute, _Inout_ OrtSessionOptions* options, bool value); + + /** + * Run fn in parallel + * + * \param[in] context + * \param[in] fn Function accepting usr_data and an integer as iterator + * \param[in] total The number of times fn is to be invoked + * \param[in] num_batch Number of batches by which the "total" is to be divided in maximum. When zero, there is no limit + * \param[in] usr_data User data to be passed back to fn + * + * \since Version 1.17. + */ + ORT_API2_STATUS(KernelContext_ParallelFor, _In_ const OrtKernelContext* context, _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch, _In_ void* usr_data); + + /** \brief Append OpenVINO execution provider to the session options + * + * If OpenVINO is not available (due to a non OpenVINO enabled build, or if OpenVINO is not installed on the system), this function will fail. + * + * \param[in] options + * \param[in] provider_options_keys + * \param[in] provider_options_values + * \param[in] num_keys + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO_V2, + _In_ OrtSessionOptions* options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** \brief Append VitisAI provider to session options + * + * If VitisAI is not available (due to a non VitisAI enabled build, or if VitisAI is not installed on the system), this function will return failure. 
+ * + * \param[in] options + * \param[in] provider_options_keys + * \param[in] provider_options_values + * \param[in] num_keys + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI, + _In_ OrtSessionOptions* options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** \brief Get scratch buffer from the corresponding allocator under the specific OrtMemoryInfo object. + * NOTE: callers are responsible to release this scratch buffer from the corresponding allocator + * \param[in] context OrtKernelContext instance + * \param[in] mem_info OrtMemoryInfo instance + * \param[in] count_or_bytes How many bytes is this scratch buffer + * \param[out] out A pointer to the scratch buffer + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out); + + /** \brief Get allocator from KernelInfo for a specific memory type. Please use C API ReleaseAllocator to release out object + * + * \param[in] info OrtKernelInfo instance + * \param[in] mem_type OrtMemType object + * \param[out] out A pointer to OrtAllocator + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out); + + /** \brief Replace initialized Tensors with external data with the provided files in memory + * + * The function will find the initialized TensorProtos with external data in the graph with the provided + * external file names and the file content in memory. The API gets the external file name, offset, data length + * from TensorProto, and locate the tensor data from the file in memory buffer.
+ * It creates a Tensor to replace the existing Tensor in graph. The replacement + * will occur before any of the optimizations take place. The data will be copied into the graph + * since TensorProto can't refer to the user provided buffers. + * + * \param[in] options + * \param[in] external_initializer_file_names Array of null terminated UTF-8 encoded strings of the file names + * which holds the external initializers. + * \param[in] external_initializer_file_buffer_array Array of pointers to the buffer of the file content. + * The buffer can be freed after session creation. + * \param[in] external_initializer_file_lengths Array of size_t to indicate the length of file content + * \param[in] num_external_initializer_files Number of external files + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(AddExternalInitializersFromFilesInMemory, _In_ OrtSessionOptions* options, + _In_reads_(num_external_initializer_files) const ORTCHAR_T* const* external_initializer_file_names, + _In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array, + _In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths, + size_t num_external_initializer_files); + + /** \brief Create an OrtLoraAdapter + * + * The function attempts to locate file specified by adapter_file_path, read it and create an OrtLoraAdapter + * instance. The adapter_file_path should be a valid path to a file that contains a valid Lora Adapter + * format. The function attempts to validate the format at load time. The file will always be memory mapped, unless + * the platform does not support memory mapping, in which case the file will be read into memory. + * + * \param[in] adapter_file_path adapter file path. + * \param[in] allocator optional pointer to a device allocator. If specified + * data is copied to the device at some point before Run() is invoked. If nullptr, data stays on CPU. 
+ * The data would still be copied to device if required by the model at inference time. + * \param[out] out A pointer to a newly created OrtLoraAdapter instance. Must be released with + * OrtApi::ReleaseLoraAdapter. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateLoraAdapter, const ORTCHAR_T* adapter_file_path, _In_ OrtAllocator* allocator, + _Outptr_ OrtLoraAdapter** out); + + /** \brief Create an OrtLoraAdapter + * + * The function copies the bytes from the array and creates an OrtLoraAdapter instance. + * + * + * \param[in] bytes pointer to a valid Lora Adapter format buffer. + * \param[in] num_bytes length of bytes buffer. + * \param[in] allocator optional pointer to a device allocator. If specified + * data is copied to the device at some point before Run() is invoked. If nullptr, data stays on CPU. + * The data would still be copied to device if required by the model at inference time. + * \param[out] out A pointer to a newly created OrtLoraAdapter instance. Must be released with + * OrtApi::ReleaseLoraAdapter. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateLoraAdapterFromArray, _In_ const void* bytes, size_t num_bytes, _In_ OrtAllocator* allocator, + _Outptr_ OrtLoraAdapter** out); + + /** \brief Release an ::OrtLoraAdapter obtained from OrtApi::CreateLoraAdapter + */ + ORT_CLASS_RELEASE(LoraAdapter); + + /** \brief Add the Lora Adapter to the list of active adapters. + * + * The function adds the Lora Adapter to the list of active adapters. The Lora Adapter must be created with + * OrtApi::CreateLoraAdapter or FromArray. The Lora Adapter will be used by the session to run the model. + * The instance of the OrtRunOptions can then be used to customize the Run() calls. + * More than one OrtLoraAdapter can be active at the same time. Lora Parameters that belong to different + * Lora adapters that will be active at the same time must not overlap. 
+ * This setting does not affect RunWithBinding. + * + * \param[in] options OrtRunOptions instance + * \param[in] adapter OrtLoraAdapter instance + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options, _In_ const OrtLoraAdapter* adapter); + + /// @} + /// \name OrtEpDynamicOptions + /// @{ + + /** \brief Set DynamicOptions for EPs (Execution Providers) + * + * Valid options can be found in `include\onnxruntime\core\session\onnxruntime_session_options_config_keys.h` + * Look for `kOrtEpDynamicOptions` + * + * \param[in] sess OrtSession + * \param[in] keys Array of null terminated UTF8 encoded strings of EP dynamic option keys + * \param[in] values Array of null terminated UTF8 encoded string of EP dynamic option values + * \param[in] kv_len Number of elements in the keys and values arrays + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys, + _In_reads_(kv_len) const char* const* values, _In_ size_t kv_len); +}; + +/* + * Steps to use a custom op: + * 1 Create an OrtCustomOpDomain with the domain name used by the custom ops + * 2 Create an OrtCustomOp structure for each op and add them to the domain + * 3 Call OrtAddCustomOpDomain to add the custom domain of ops to the session options + */ + +// Specifies some characteristics of inputs/outputs of custom ops: +// Specify if the inputs/outputs are one of: +// 1) Non-optional (input/output must be present in the node) +// 2) Optional (input/output may be absent in the node) +// 3) Variadic: A variadic input or output specifies N (i.e., the minimum arity) or more operands. +// Only the last input or output of a custom op may be marked as variadic. +// The homogeneity of the variadic input or output determines whether all operands must be of the same +// tensor element type. 
+typedef enum OrtCustomOpInputOutputCharacteristic { + INPUT_OUTPUT_REQUIRED = 0, + INPUT_OUTPUT_OPTIONAL, + INPUT_OUTPUT_VARIADIC, +} OrtCustomOpInputOutputCharacteristic; + +/* + * The OrtCustomOp structure defines a custom op's schema and its kernel callbacks. The callbacks are filled in by + * the implementor of the custom op. + */ +struct OrtCustomOp { + uint32_t version; // Must be initialized to ORT_API_VERSION + + // This callback creates the kernel, which is a user defined + // parameter that is passed to the Kernel* callbacks below. It is + // recommended to use CreateKernelV2 which allows for a safe error + // propagation by returning an OrtStatusPtr. + void*(ORT_API_CALL* CreateKernel)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api, + _In_ const OrtKernelInfo* info); + + // Returns the name of the op + const char*(ORT_API_CALL* GetName)(_In_ const struct OrtCustomOp* op); + + // Returns the type of the execution provider, return nullptr to use CPU execution provider + const char*(ORT_API_CALL* GetExecutionProviderType)(_In_ const struct OrtCustomOp* op); + + // Returns the count and types of the input & output tensors + ONNXTensorElementDataType(ORT_API_CALL* GetInputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); + size_t(ORT_API_CALL* GetInputTypeCount)(_In_ const struct OrtCustomOp* op); + ONNXTensorElementDataType(ORT_API_CALL* GetOutputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); + size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ const struct OrtCustomOp* op); + + // Perform a computation step. It is recommended to use + // KernelComputeV2 which allows for a safe error propagation by + // returning an OrtStatusPtr. 
+ void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtKernelContext* context); + void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel); + + // Returns the characteristics of the input & output tensors + OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetInputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index); + OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetOutputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index); + + // Returns the memory type of the input tensors. This API allows the custom op + // to place the inputs on specific devices. By default, it returns + // OrtMemTypeDefault, which means the input is placed on the default device for + // the execution provider. If the inputs need to have different memory types, + // this function can be overridden to return the specific memory types. + OrtMemType(ORT_API_CALL* GetInputMemoryType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); + + // Returns the minimum number of input arguments expected for the variadic input. + // Applicable only for custom ops that have a variadic input. + int(ORT_API_CALL* GetVariadicInputMinArity)(_In_ const struct OrtCustomOp* op); + + // Returns true (non-zero) if all arguments of a variadic input have to be of the same type (homogeneous), + // and false (zero) otherwise. + // Applicable only for custom ops that have a variadic input. + int(ORT_API_CALL* GetVariadicInputHomogeneity)(_In_ const struct OrtCustomOp* op); + + // Returns the minimum number of output values expected for the variadic output. + // Applicable only for custom ops that have a variadic output. + int(ORT_API_CALL* GetVariadicOutputMinArity)(_In_ const struct OrtCustomOp* op); + + // Returns true (non-zero) if all output values of a variadic output have to be of the same type (homogeneous), + // and false (zero) otherwise. + // Applicable only for custom ops that have a variadic output.
+ int(ORT_API_CALL* GetVariadicOutputHomogeneity)(_In_ const struct OrtCustomOp* op); + + // Create the kernel state which is passed to each compute call. + OrtStatusPtr(ORT_API_CALL* CreateKernelV2)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api, + _In_ const OrtKernelInfo* info, + _Out_ void** kernel); + + // Perform the computation step. + OrtStatusPtr(ORT_API_CALL* KernelComputeV2)(_In_ void* op_kernel, _In_ OrtKernelContext* context); + + OrtStatusPtr(ORT_API_CALL* InferOutputShapeFn)(_In_ const struct OrtCustomOp* op, _In_ OrtShapeInferContext*); + + // Get start range + int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op); + int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op); + + // Get the inplace_map that defines which output can reuse which input + // Callers will provide 2 raw int* and pass in their address, this function will fill these 2 arrays + // when return, output (*output_index)[i] may reuse the input (*input_index[i]). + // The return value is the size of these 2 arrays. + // Callers are responsible to delete these 2 arrays after use by calling OrtCustomOp::ReleaseMayInplace(). + size_t(ORT_API_CALL* GetMayInplace)(_Out_ int** input_index, _Out_ int** output_index); + + // Release the pointer input_index and output_index allocated from GetMayInplace() function. + // If GetMayInplace() is defined, this function MUST be defined as well. 
+ void(ORT_API_CALL* ReleaseMayInplace)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index); + + // Same as GetMayInplace() and ReleaseMayInplace() + size_t(ORT_API_CALL* GetAliasMap)(_Out_ int** input_index, _Out_ int** output_index); + void(ORT_API_CALL* ReleaseAliasMap)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index); +}; /* * This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality * This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists * * \param device_id CUDA device id, starts from zero. -*/ + */ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id); +/* + * This is the old way to add the ROCm provider to the session, please use + * SessionOptionsAppendExecutionProvider_ROCM above to access the latest functionality + * This function always exists, but will only succeed if Onnxruntime was built with + * HIP support and the ROCm provider shared library exists + * + * \param device_id HIP device id, starts from zero. + */ +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_ROCM, _In_ OrtSessionOptions* options, int device_id); + +/* + * This is the old way to add the MIGraphX provider to the session, please use + * SessionOptionsAppendExecutionProvider_MIGraphX above to access the latest functionality + * This function always exists, but will only succeed if Onnxruntime was built with + * HIP support and the MIGraphX provider shared library exists + * + * \param device_id HIP device id, starts from zero. 
+ */ +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtSessionOptions* options, int device_id); + +/* + * This is the old way to add the oneDNN provider to the session, please use + * SessionOptionsAppendExecutionProvider_oneDNN above to access the latest functionality + * This function always exists, but will only succeed if Onnxruntime was built with + * oneDNN support and the oneDNN provider shared library exists + * + * \param use_arena zero: false. non-zero: true. + */ +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena); + +/* + * This is the old way to add the TensorRT provider to the session, please use SessionOptionsAppendExecutionProvider_TensorRT_V2 above to access the latest functionality + * This function always exists, but will only succeed if Onnxruntime was built with TensorRT support and the TensorRT provider shared library exists + * + * \param device_id CUDA device id, starts from zero. + */ +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id); + #ifdef __cplusplus } #endif - -//! @} +/// @} diff --git a/libs/onnxruntime/include/onnxruntime_cxx_api.h b/libs/onnxruntime/include/onnxruntime_cxx_api.h index fb8f481..ff196cf 100644 --- a/libs/onnxruntime/include/onnxruntime_cxx_api.h +++ b/libs/onnxruntime/include/onnxruntime_cxx_api.h @@ -4,22 +4,36 @@ // Summary: The Ort C++ API is a header only wrapper around the Ort C API. // // The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors -// and automatically releasing resources in the destructors. +// and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so +// all the resources follow RAII and do not leak memory. // // Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers. 
-// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). +// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them +// until you assign an instance that actually holds an underlying object. // -// Only move assignment between objects is allowed, there are no copy constructors. Some objects have explicit 'Clone' -// methods for this purpose. +// For Ort objects only move assignment between objects is allowed, there are no copy constructors. +// Some objects have explicit 'Clone' methods for this purpose. +// +// ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments +// by value or by reference. ConstXXXX types are restricted to const only interfaces. +// +// UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces. +// +// The lifetime of the corresponding owning object must eclipse the lifetimes of the ConstXXXX/UnownedXXXX types. They exist so you do not +// have to fall back to C types and the API with the usual pitfalls. In general, do not use C API from your C++ code.
#pragma once #include "onnxruntime_c_api.h" +#include "onnxruntime_float16.h" + #include +#include #include #include #include #include #include +#include #include #include @@ -28,14 +42,14 @@ #endif /** \brief All C++ Onnxruntime APIs are defined inside this namespace -* -*/ + * + */ namespace Ort { /** \brief All C++ methods that can fail will throw an exception of this type -* -* If ORT_NO_EXCEPTIONS is defined, then any error will result in a call to abort() -*/ + * + * If ORT_NO_EXCEPTIONS is defined, then any error will result in a call to abort() + */ struct Exception : std::exception { Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {} @@ -48,6 +62,9 @@ struct Exception : std::exception { }; #ifdef ORT_NO_EXCEPTIONS +// The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors. +// NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW +#ifndef ORT_CXX_API_THROW #define ORT_CXX_API_THROW(string, code) \ do { \ std::cerr << Ort::Exception(string, code) \ @@ -55,12 +72,14 @@ struct Exception : std::exception { << std::endl; \ abort(); \ } while (false) +#endif #else #define ORT_CXX_API_THROW(string, code) \ throw Ort::Exception(string, code) #endif -// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, it's in a template so that we can define a global variable in a header and make +// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, +// it's in a template so that we can define a global variable in a header and make // it transparent to the users of the API. 
template struct Global { @@ -71,17 +90,413 @@ struct Global { template #ifdef ORT_API_MANUAL_INIT const OrtApi* Global::api_{}; -inline void InitApi() { Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); } +inline void InitApi() noexcept { Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); } + +// Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is +// required by C++ APIs. +// +// Example mycustomop.cc: +// +// #define ORT_API_MANUAL_INIT +// #include +// #undef ORT_API_MANUAL_INIT +// +// OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) { +// Ort::InitApi(api_base->GetApi(ORT_API_VERSION)); +// // ... +// } +// +inline void InitApi(const OrtApi* api) noexcept { Global::api_ = api; } #else +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +// "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers. +// Please define ORT_API_MANUAL_INIT if it concerns you. +#pragma warning(disable : 26426) +#endif const OrtApi* Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif #endif /// This returns a reference to the OrtApi interface in use -inline const OrtApi& GetApi() { return *Global::api_; } - -/// This is a C++ wrapper for OrtApi::GetAvailableProviders() and returns a vector of strings representing the available execution providers. +inline const OrtApi& GetApi() noexcept { return *Global::api_; } + +/// +/// This function returns the onnxruntime version string +/// +/// version string major.minor.rev +std::string GetVersionString(); + +/// +/// This function returns the onnxruntime build information: including git branch, +/// git commit id, build type(Debug/Release/RelWithDebInfo) and cmake cpp flags.
+/// +/// string +std::string GetBuildInfoString(); + +/// +/// This is a C++ wrapper for OrtApi::GetAvailableProviders() and +/// returns a vector of strings representing the available execution providers. +/// +/// vector of strings std::vector GetAvailableProviders(); +/** \brief IEEE 754 half-precision floating point data type + * + * \details This struct is used for converting float to float16 and back + * so the user could feed inputs and fetch outputs using these types. + * + * The size of the structure should align with uint16_t and one can freely cast + * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data. + * + * \code{.unparsed} + * // This example demonstrates conversion from float to float16 + * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f}; + * std::vector fp16_values; + * fp16_values.reserve(std::size(values)); + * std::transform(std::begin(values), std::end(values), std::back_inserter(fp16_values), + * [](float value) { return Ort::Float16_t(value); }); + * + * \endcode + */ +struct Float16_t : onnxruntime_float16::Float16Impl { + private: + /// + /// Constructor from a 16-bit representation of a float16 value + /// No conversion is done here. + /// + /// 16-bit representation + constexpr explicit Float16_t(uint16_t v) noexcept { val = v; } + + public: + using Base = onnxruntime_float16::Float16Impl; + + /// + /// Default constructor + /// + Float16_t() = default; + + /// + /// Explicit conversion from uint16_t bit representation of float16. + /// + /// uint16_t bit representation of float16 + /// new instance of Float16_t + constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); } + + /// + /// __ctor from float. Float is converted into float16 16-bit representation.
+ /// + /// float value + explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); } + + /// + /// Converts float16 to float + /// + /// float representation of float16 value + float ToFloat() const noexcept { return Base::ToFloatImpl(); } + + /// + /// Checks if the value is negative + /// + /// true if negative + using Base::IsNegative; + + /// + /// Tests if the value is NaN + /// + /// true if NaN + using Base::IsNaN; + + /// + /// Tests if the value is finite + /// + /// true if finite + using Base::IsFinite; + + /// + /// Tests if the value represents positive infinity. + /// + /// true if positive infinity + using Base::IsPositiveInfinity; + + /// + /// Tests if the value represents negative infinity + /// + /// true if negative infinity + using Base::IsNegativeInfinity; + + /// + /// Tests if the value is either positive or negative infinity. + /// + /// True if absolute value is infinity + using Base::IsInfinity; + + /// + /// Tests if the value is NaN or zero. Useful for comparisons. + /// + /// True if NaN or zero. + using Base::IsNaNOrZero; + + /// + /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). + /// + /// True if so + using Base::IsNormal; + + /// + /// Tests if the value is subnormal (denormal). + /// + /// True if so + using Base::IsSubnormal; + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + using Base::Abs; + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + using Base::Negate; + + /// + /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check + /// for two values by or'ing the private bits together and stripping the sign. They are both zero, + /// and therefore equivalent, if the resulting value is still zero. + /// + /// first value + /// second value + /// True if both arguments represent zero + using Base::AreZero; + + /// + /// User defined conversion operator. 
Converts Float16_t to float. + /// + explicit operator float() const noexcept { return ToFloat(); } + + using Base::operator==; + using Base::operator!=; + using Base::operator<; +}; + +static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match"); + +/** \brief bfloat16 (Brain Floating Point) data type + * + * \details This struct is used for converting float to bfloat16 and back + * so the user could feed inputs and fetch outputs using these types. + * + * The size of the structure should align with uint16_t and one can freely cast + * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data. + * + * \code{.unparsed} + * // This example demonstrates conversion from float to bfloat16 + * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f}; + * std::vector bfp16_values; + * bfp16_values.reserve(std::size(values)); + * std::transform(std::begin(values), std::end(values), std::back_inserter(bfp16_values), + * [](float value) { return Ort::BFloat16_t(value); }); + * + * \endcode + */ +struct BFloat16_t : onnxruntime_float16::BFloat16Impl { + private: + /// + /// Constructor from a uint16_t representation of bfloat16 + /// used in FromBits() to escape overload resolution issue with + /// constructor from float. + /// No conversion is done. + /// + /// 16-bit bfloat16 value + constexpr explicit BFloat16_t(uint16_t v) noexcept { val = v; } + + public: + using Base = onnxruntime_float16::BFloat16Impl; + + BFloat16_t() = default; + + /// + /// Explicit conversion from uint16_t bit representation of bfloat16. + /// + /// uint16_t bit representation of bfloat16 + /// new instance of BFloat16_t + static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); } + + /// + /// __ctor from float. Float is converted into bfloat16 16-bit representation.
+ /// + /// float value + explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); } + + /// + /// Converts bfloat16 to float + /// + /// float representation of bfloat16 value + float ToFloat() const noexcept { return Base::ToFloatImpl(); } + + /// + /// Checks if the value is negative + /// + /// true if negative + using Base::IsNegative; + + /// + /// Tests if the value is NaN + /// + /// true if NaN + using Base::IsNaN; + + /// + /// Tests if the value is finite + /// + /// true if finite + using Base::IsFinite; + + /// + /// Tests if the value represents positive infinity. + /// + /// true if positive infinity + using Base::IsPositiveInfinity; + + /// + /// Tests if the value represents negative infinity + /// + /// true if negative infinity + using Base::IsNegativeInfinity; + + /// + /// Tests if the value is either positive or negative infinity. + /// + /// True if absolute value is infinity + using Base::IsInfinity; + + /// + /// Tests if the value is NaN or zero. Useful for comparisons. + /// + /// True if NaN or zero. + using Base::IsNaNOrZero; + + /// + /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). + /// + /// True if so + using Base::IsNormal; + + /// + /// Tests if the value is subnormal (denormal). + /// + /// True if so + using Base::IsSubnormal; + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + using Base::Abs; + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + using Base::Negate; + + /// + /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check + /// for two values by or'ing the private bits together and stripping the sign. They are both zero, + /// and therefore equivalent, if the resulting value is still zero. + /// + /// first value + /// second value + /// True if both arguments represent zero + using Base::AreZero; + + /// + /// User defined conversion operator. 
Converts BFloat16_t to float. + /// + explicit operator float() const noexcept { return ToFloat(); } + + // We do not have an inherited impl for the below operators + // as the internal class implements them a little differently + bool operator==(const BFloat16_t& rhs) const noexcept; + bool operator!=(const BFloat16_t& rhs) const noexcept { return !(*this == rhs); } + bool operator<(const BFloat16_t& rhs) const noexcept; +}; + +static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match"); + +/** \brief float8e4m3fn (Float8 Floating Point) data type + * \details It is necessary for type dispatching to make use of C++ API + * The type is implicitly convertible to/from uint8_t. + * See https://onnx.ai/onnx/technical/float8.html for further details. + */ +struct Float8E4M3FN_t { + uint8_t value; + constexpr Float8E4M3FN_t() noexcept : value(0) {} + constexpr Float8E4M3FN_t(uint8_t v) noexcept : value(v) {} + constexpr operator uint8_t() const noexcept { return value; } + // nan values are treated like any other value for operator ==, != + constexpr bool operator==(const Float8E4M3FN_t& rhs) const noexcept { return value == rhs.value; }; + constexpr bool operator!=(const Float8E4M3FN_t& rhs) const noexcept { return value != rhs.value; }; +}; + +static_assert(sizeof(Float8E4M3FN_t) == sizeof(uint8_t), "Sizes must match"); + +/** \brief float8e4m3fnuz (Float8 Floating Point) data type + * \details It is necessary for type dispatching to make use of C++ API + * The type is implicitly convertible to/from uint8_t. + * See https://onnx.ai/onnx/technical/float8.html for further details. 
+ */ +struct Float8E4M3FNUZ_t { + uint8_t value; + constexpr Float8E4M3FNUZ_t() noexcept : value(0) {} + constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept : value(v) {} + constexpr operator uint8_t() const noexcept { return value; } + // nan values are treated like any other value for operator ==, != + constexpr bool operator==(const Float8E4M3FNUZ_t& rhs) const noexcept { return value == rhs.value; }; + constexpr bool operator!=(const Float8E4M3FNUZ_t& rhs) const noexcept { return value != rhs.value; }; +}; + +static_assert(sizeof(Float8E4M3FNUZ_t) == sizeof(uint8_t), "Sizes must match"); + +/** \brief float8e5m2 (Float8 Floating Point) data type + * \details It is necessary for type dispatching to make use of C++ API + * The type is implicitly convertible to/from uint8_t. + * See https://onnx.ai/onnx/technical/float8.html for further details. + */ +struct Float8E5M2_t { + uint8_t value; + constexpr Float8E5M2_t() noexcept : value(0) {} + constexpr Float8E5M2_t(uint8_t v) noexcept : value(v) {} + constexpr operator uint8_t() const noexcept { return value; } + // nan values are treated like any other value for operator ==, != + constexpr bool operator==(const Float8E5M2_t& rhs) const noexcept { return value == rhs.value; }; + constexpr bool operator!=(const Float8E5M2_t& rhs) const noexcept { return value != rhs.value; }; +}; + +static_assert(sizeof(Float8E5M2_t) == sizeof(uint8_t), "Sizes must match"); + +/** \brief float8e5m2fnuz (Float8 Floating Point) data type + * \details It is necessary for type dispatching to make use of C++ API + * The type is implicitly convertible to/from uint8_t. + * See https://onnx.ai/onnx/technical/float8.html for further details. 
+ */ +struct Float8E5M2FNUZ_t { + uint8_t value; + constexpr Float8E5M2FNUZ_t() noexcept : value(0) {} + constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept : value(v) {} + constexpr operator uint8_t() const noexcept { return value; } + // nan values are treated like any other value for operator ==, != + constexpr bool operator==(const Float8E5M2FNUZ_t& rhs) const noexcept { return value == rhs.value; }; + constexpr bool operator!=(const Float8E5M2FNUZ_t& rhs) const noexcept { return value != rhs.value; }; +}; + +static_assert(sizeof(Float8E5M2FNUZ_t) == sizeof(uint8_t), "Sizes must match"); + +namespace detail { // This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type // This can't be done in the C API since C doesn't have function overloading. #define ORT_DEFINE_RELEASE(NAME) \ @@ -90,8 +505,10 @@ std::vector GetAvailableProviders(); ORT_DEFINE_RELEASE(Allocator); ORT_DEFINE_RELEASE(MemoryInfo); ORT_DEFINE_RELEASE(CustomOpDomain); +ORT_DEFINE_RELEASE(ThreadingOptions); ORT_DEFINE_RELEASE(Env); ORT_DEFINE_RELEASE(RunOptions); +ORT_DEFINE_RELEASE(LoraAdapter); ORT_DEFINE_RELEASE(Session); ORT_DEFINE_RELEASE(SessionOptions); ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo); @@ -100,148 +517,185 @@ ORT_DEFINE_RELEASE(MapTypeInfo); ORT_DEFINE_RELEASE(TypeInfo); ORT_DEFINE_RELEASE(Value); ORT_DEFINE_RELEASE(ModelMetadata); -ORT_DEFINE_RELEASE(ThreadingOptions); ORT_DEFINE_RELEASE(IoBinding); ORT_DEFINE_RELEASE(ArenaCfg); +ORT_DEFINE_RELEASE(Status); +ORT_DEFINE_RELEASE(OpAttr); +ORT_DEFINE_RELEASE(Op); +ORT_DEFINE_RELEASE(KernelInfo); #undef ORT_DEFINE_RELEASE -/** \brief IEEE 754 half-precision floating point data type - * \details It is necessary for type dispatching to make use of C++ API - * The type is implicitly convertible to/from uint16_t. 
- * The size of the structure should align with uint16_t and one can freely cast - * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data. - * - * Generally, you can feed any of your types as float16/blfoat16 data to create a tensor - * on top of it, providing it can form a continuous buffer with 16-bit elements with no padding. - * And you can also feed a array of uint16_t elements directly. For example, - * - * \code{.unparsed} - * uint16_t values[] = { 15360, 16384, 16896, 17408, 17664}; - * constexpr size_t values_length = sizeof(values) / sizeof(values[0]); - * std::vector dims = {values_length}; // one dimensional example - * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); - * // Note we are passing bytes count in this api, not number of elements -> sizeof(values) - * auto float16_tensor = Ort::Value::CreateTensor(info, values, sizeof(values), - * dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16); - * \endcode - * - * Here is another example, a little bit more elaborate. Let's assume that you use your own float16 type and you want to use - * a templated version of the API above so the type is automatically set based on your type. You will need to supply an extra - * template specialization. 
- * - * \code{.unparsed} - * namespace yours { struct half {}; } // assume this is your type, define this: - * namespace Ort { - * template<> - * struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; }; - * } //namespace Ort - * - * std::vector values; - * std::vector dims = {values.size()}; // one dimensional example - * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); - * // Here we are passing element count -> values.size() - * auto float16_tensor = Ort::Value::CreateTensor(info, values.data(), values.size(), dims.data(), dims.size()); - * - * \endcode - */ -struct Float16_t { - uint16_t value; - constexpr Float16_t() noexcept : value(0) {} - constexpr Float16_t(uint16_t v) noexcept : value(v) {} - constexpr operator uint16_t() const noexcept { return value; } - constexpr bool operator==(const Float16_t& rhs) const noexcept { return value == rhs.value; }; - constexpr bool operator!=(const Float16_t& rhs) const noexcept { return value != rhs.value; }; -}; - -static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match"); - -/** \brief bfloat16 (Brain Floating Point) data type - * \details It is necessary for type dispatching to make use of C++ API - * The type is implicitly convertible to/from uint16_t. - * The size of the structure should align with uint16_t and one can freely cast - * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data. - * - * See also code examples for Float16_t above. - */ -struct BFloat16_t { - uint16_t value; - constexpr BFloat16_t() noexcept : value(0) {} - constexpr BFloat16_t(uint16_t v) noexcept : value(v) {} - constexpr operator uint16_t() const noexcept { return value; } - constexpr bool operator==(const BFloat16_t& rhs) const noexcept { return value == rhs.value; }; - constexpr bool operator!=(const BFloat16_t& rhs) const noexcept { return value != rhs.value; }; +/** \brief This is a tagging template type. 
Use it with Base to indicate that the C++ interface object + * has no ownership of the underlying C object. + */ +template +struct Unowned { + using Type = T; }; -static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match"); - -/** \brief Used internally by the C++ API. C++ wrapper types inherit from this -* -* This is a zero cost abstraction to wrap the C API objects and delete them on destruction. -* There is a secondary class 'Unowned' that is used to prevent deletion on destruction (Used for return types that are -* not owned by the caller) -* -*/ +/** \brief Used internally by the C++ API. C++ wrapper types inherit from this. + * This is a zero cost abstraction to wrap the C API objects and delete them on destruction. + * + * All of the C++ classes + * a) serve as containers for pointers to objects that are created by the underlying C API. + * Their size is just a pointer size, no need to dynamically allocate them. Use them by value. + * b) Each of struct XXXX, XXX instances function as smart pointers to the underlying C API objects. + * they would release objects owned automatically when going out of scope, they are move-only. + * c) ConstXXXX and UnownedXXX structs function as non-owning, copyable containers for the above pointers. + * ConstXXXX allow calling const interfaces only. They give access to objects that are owned by somebody else + * such as Onnxruntime or instances of XXXX classes. + * d) serve convenient interfaces that return C++ objects and further enhance exception and type safety so they can be used + * in C++ code. + * + */ + +/// +/// This is a non-const pointer holder that is move-only. Disposes of the pointer on destruction. 
+/// template struct Base { using contained_type = T; - Base() = default; - Base(T* p) : p_{p} { - if (!p) - ORT_CXX_API_THROW("Allocation failure", ORT_FAIL); - } + constexpr Base() = default; + constexpr explicit Base(contained_type* p) noexcept : p_{p} {} ~Base() { OrtRelease(p_); } - operator T*() { return p_; } - operator const T*() const { return p_; } + Base(const Base&) = delete; + Base& operator=(const Base&) = delete; + + Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } + Base& operator=(Base&& v) noexcept { + OrtRelease(p_); + p_ = v.release(); + return *this; + } + + constexpr operator contained_type*() const noexcept { return p_; } - /// \brief Releases ownership of the contained pointer - T* release() { + /// \brief Relinquishes ownership of the contained C object pointer + /// The underlying object is not destroyed + contained_type* release() { T* p = p_; p_ = nullptr; return p; } protected: - Base(const Base&) = delete; - Base& operator=(const Base&) = delete; + contained_type* p_{}; +}; + +// Undefined. For const types use Base> +template +struct Base; + +/// +/// Covers unowned pointers owned by either the ORT +/// or some other instance of CPP wrappers. +/// Used for ConstXXX and UnownedXXXX types that are copyable. +/// Also convenient to wrap raw OrtXX pointers . 
+/// +/// +template +struct Base> { + using contained_type = typename Unowned::Type; + + constexpr Base() = default; + constexpr explicit Base(contained_type* p) noexcept : p_{p} {} + + ~Base() = default; + + Base(const Base&) = default; + Base& operator=(const Base&) = default; + Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } - void operator=(Base&& v) noexcept { - OrtRelease(p_); - p_ = v.release(); + Base& operator=(Base&& v) noexcept { + p_ = nullptr; + std::swap(p_, v.p_); + return *this; } - T* p_{}; + constexpr operator contained_type*() const noexcept { return p_; } - template - friend struct Unowned; // This friend line is needed to keep the centos C++ compiler from giving an error + protected: + contained_type* p_{}; }; -/** \brief Wraps an object that inherits from Ort::Base and stops it from deleting the contained pointer on destruction -* -* This has the effect of making it not own the memory held by Ort::Base. -*/ -template -struct Unowned : T { - Unowned(typename T::contained_type* p) : T{p} {} - Unowned(Unowned&& v) : T{v.p_} {} - ~Unowned() { this->release(); } +// Light functor to release memory with OrtAllocator +struct AllocatedFree { + OrtAllocator* allocator_; + explicit AllocatedFree(OrtAllocator* allocator) + : allocator_(allocator) {} + void operator()(void* ptr) const { + if (ptr) allocator_->Free(allocator_, ptr); + } }; +} // namespace detail + struct AllocatorWithDefaultOptions; -struct MemoryInfo; struct Env; struct TypeInfo; struct Value; struct ModelMetadata; +/** \brief unique_ptr typedef used to own strings allocated by OrtAllocators + * and release them at the end of the scope. The lifespan of the given allocator + * must eclipse the lifespan of AllocatedStringPtr instance + */ +using AllocatedStringPtr = std::unique_ptr; + +/** \brief The Status that holds ownership of OrtStatus received from C API + * Use it to safely destroy OrtStatus* returned from the C API. 
Use appropriate + * constructors to construct an instance of a Status object from exceptions. + */ +struct Status : detail::Base { + explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used + explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API. + explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception + explicit Status(const std::exception&) noexcept; ///< Creates status instance out of exception + Status(const char* message, OrtErrorCode code) noexcept; ///< Creates status instance out of null-terminated string message. + std::string GetErrorMessage() const; + OrtErrorCode GetErrorCode() const; + bool IsOK() const noexcept; ///< Returns true if instance represents an OK (non-error) status. +}; + +/** \brief The ThreadingOptions + * + * The ThreadingOptions used for set global threadpools' options of The Env. + */ +struct ThreadingOptions : detail::Base { + /// \brief Wraps OrtApi::CreateThreadingOptions + ThreadingOptions(); + + /// \brief Wraps OrtApi::SetGlobalIntraOpNumThreads + ThreadingOptions& SetGlobalIntraOpNumThreads(int intra_op_num_threads); + + /// \brief Wraps OrtApi::SetGlobalInterOpNumThreads + ThreadingOptions& SetGlobalInterOpNumThreads(int inter_op_num_threads); + + /// \brief Wraps OrtApi::SetGlobalSpinControl + ThreadingOptions& SetGlobalSpinControl(int allow_spinning); + + /// \brief Wraps OrtApi::SetGlobalDenormalAsZero + ThreadingOptions& SetGlobalDenormalAsZero(); + + /// \brief Wraps OrtApi::SetGlobalCustomCreateThreadFn + ThreadingOptions& SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); + + /// \brief Wraps OrtApi::SetGlobalCustomThreadCreationOptions + ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options); + + /// \brief Wraps OrtApi::SetGlobalCustomJoinThreadFn + ThreadingOptions& 
SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); +}; + /** \brief The Env (Environment) -* -* The Env holds the logging state used by all other objects. -* Note: One Env must be created before using any other Onnxruntime functionality -*/ -struct Env : Base { + * + * The Env holds the logging state used by all other objects. + * Note: One Env must be created before using any other Onnxruntime functionality + */ +struct Env : detail::Base { explicit Env(std::nullptr_t) {} ///< Create an empty Env object, must be assigned a valid one to be used /// \brief Wraps OrtApi::CreateEnv @@ -263,22 +717,56 @@ struct Env : Base { Env& EnableTelemetryEvents(); ///< Wraps OrtApi::EnableTelemetryEvents Env& DisableTelemetryEvents(); ///< Wraps OrtApi::DisableTelemetryEvents + Env& UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level); ///< Wraps OrtApi::UpdateEnvWithCustomLogLevel + Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator + + Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map& options, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2 }; /** \brief Custom Op Domain -* -*/ -struct CustomOpDomain : Base { + * + */ +struct CustomOpDomain : detail::Base { explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used /// \brief Wraps OrtApi::CreateCustomOpDomain explicit CustomOpDomain(const char* domain); - void Add(OrtCustomOp* op); ///< Wraps CustomOpDomain_Add + // This does not take ownership of the op, simply registers it. 
+ void Add(const OrtCustomOp* op); ///< Wraps CustomOpDomain_Add +}; + +/// \brief LoraAdapter holds a set of Lora Parameters loaded from a single file +struct LoraAdapter : detail::Base { + using Base = detail::Base; + using Base::Base; + + explicit LoraAdapter(std::nullptr_t) {} ///< Create an empty LoraAdapter object, must be assigned a valid one to be used + /// \brief Wraps OrtApi::CreateLoraAdapter + /// + /// The function attempts to load the adapter from the specified file + /// \param adapter_path The path to the Lora adapter + /// \param allocator optional pointer to a device allocator. If nullptr, the data stays on CPU. It would still + /// be copied to device if required by the model at inference time. + static LoraAdapter CreateLoraAdapter(const std::basic_string& adapter_path, + OrtAllocator* allocator); + + /// \brief Wraps OrtApi::CreateLoraAdapterFromArray + /// + /// The function attempts to load the adapter from the specified byte array. + /// \param bytes The byte array containing file LoraAdapter format + /// \param num_bytes The number of bytes in the byte array + /// \param allocator optional pointer to a device allocator. If nullptr, the data stays on CPU. It would still + /// be copied to device if required by the model at inference time. 
+ static LoraAdapter CreateLoraAdapterFromArray(const void* bytes, size_t num_bytes, + OrtAllocator* allocator); }; -struct RunOptions : Base { +/** \brief RunOptions + * + */ +struct RunOptions : detail::Base { explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used RunOptions(); ///< Wraps OrtApi::CreateRunOptions @@ -294,254 +782,798 @@ struct RunOptions : Base { RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry /** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance - * - * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error - * Wraps OrtApi::RunOptionsSetTerminate - */ + * + * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error + * Wraps OrtApi::RunOptionsSetTerminate + */ RunOptions& SetTerminate(); /** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating - * - * Wraps OrtApi::RunOptionsUnsetTerminate - */ + * + * Wraps OrtApi::RunOptionsUnsetTerminate + */ RunOptions& UnsetTerminate(); + + /** \brief Add the LoraAdapter to the list of active adapters. + * The setting does not affect RunWithBinding() calls. + * + * Wraps OrtApi::RunOptionsAddActiveLoraAdapter + * \param adapter The LoraAdapter to be used as the active adapter + */ + RunOptions& AddActiveLoraAdapter(const LoraAdapter& adapter); +}; + +namespace detail { +// Utility function that returns a SessionOption config entry key for a specific custom operator. 
+// Ex: custom_op.[custom_op_name].[config] +std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config); +} // namespace detail + +/// +/// Class that represents session configuration entries for one or more custom operators. +/// +/// Example: +/// Ort::CustomOpConfigs op_configs; +/// op_configs.AddConfig("my_custom_op", "device_type", "CPU"); +/// +/// Passed to Ort::SessionOptions::RegisterCustomOpsLibrary. +/// +struct CustomOpConfigs { + CustomOpConfigs() = default; + ~CustomOpConfigs() = default; + CustomOpConfigs(const CustomOpConfigs&) = default; + CustomOpConfigs& operator=(const CustomOpConfigs&) = default; + CustomOpConfigs(CustomOpConfigs&& o) = default; + CustomOpConfigs& operator=(CustomOpConfigs&& o) = default; + + /** \brief Adds a session configuration entry/value for a specific custom operator. + * + * \param custom_op_name The name of the custom operator for which to add a configuration entry. + * Must match the name returned by the CustomOp's GetName() method. + * \param config_key The name of the configuration entry. + * \param config_value The value of the configuration entry. + * \return A reference to this object to enable call chaining. + */ + CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value); + + /** \brief Returns a flattened map of custom operator configuration entries and their values. + * + * The keys has been flattened to include both the custom operator name and the configuration entry key name. + * For example, a prior call to AddConfig("my_op", "key", "value") corresponds to the flattened key/value pair + * {"my_op.key", "value"}. + * + * \return An unordered map of flattened configurations. 
+ */ + const std::unordered_map& GetFlattenedConfigs() const; + + private: + std::unordered_map flat_configs_; }; /** \brief Options object used when creating a new Session object -* -* Wraps ::OrtSessionOptions object and methods -*/ -struct SessionOptions : Base { - explicit SessionOptions(std::nullptr_t) {} ///< Create an empty SessionOptions object, must be assigned a valid one to be used - SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions - explicit SessionOptions(OrtSessionOptions* p) : Base{p} {} ///< Used for interop with the C API + * + * Wraps ::OrtSessionOptions object and methods + */ + +struct SessionOptions; + +namespace detail { +// we separate const-only methods because passing const ptr to non-const methods +// is only discovered when inline methods are compiled which is counter-intuitive +template +struct ConstSessionOptionsImpl : Base { + using B = Base; + using B::B; SessionOptions Clone() const; ///< Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions - SessionOptions& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads - SessionOptions& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads - SessionOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel + std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry + bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry + std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def); +}; + +template +struct SessionOptionsImpl : ConstSessionOptionsImpl { + using B = ConstSessionOptionsImpl; + using B::B; + + SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads + SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps 
OrtApi::SetInterOpNumThreads + SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel + SessionOptionsImpl& SetDeterministicCompute(bool value); ///< Wraps OrtApi::SetDeterministicCompute + + SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena + SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena + + SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath - SessionOptions& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena - SessionOptions& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena + SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps OrtApi::EnableProfiling + SessionOptionsImpl& DisableProfiling(); ///< Wraps OrtApi::DisableProfiling - SessionOptions& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath + SessionOptionsImpl& EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps - SessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps OrtApi::EnableProfiling - SessionOptions& DisableProfiling(); ///< Wraps OrtApi::DisableProfiling + SessionOptionsImpl& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern + SessionOptionsImpl& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern - SessionOptions& EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps + SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode - SessionOptions& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern - SessionOptions& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern + SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId + SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps 
OrtApi::SetSessionLogSeverityLevel - SessionOptions& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode + SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain - SessionOptions& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId - SessionOptions& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel + SessionOptionsImpl& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads - SessionOptions& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain + SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry - SessionOptions& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads + SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer + SessionOptionsImpl& AddExternalInitializers(const std::vector& names, const std::vector& ort_values); ///< Wraps OrtApi::AddExternalInitializers + SessionOptionsImpl& AddExternalInitializersFromFilesInMemory(const std::vector>& external_initializer_file_names, + const std::vector& external_initializer_file_buffer_array, + const std::vector& external_initializer_file_lengths); ///< Wraps OrtApi::AddExternalInitializersFromFilesInMemory - SessionOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry - SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer + SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA + SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2 + SessionOptionsImpl& 
AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM + SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO + ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO_V2 + SessionOptionsImpl& AppendExecutionProvider_OpenVINO_V2(const std::unordered_map& provider_options = {}); + SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT + SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT + SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX + ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN + SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options); + ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl + SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options); + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK. 
+ SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name, + const std::unordered_map& provider_options = {}); - SessionOptions& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA - SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM - SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO - SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT + SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn + SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions + SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn - SessionOptions& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn - SessionOptions& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions - SessionOptions& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn + ///< Registers the custom operator from the specified shared library via OrtApi::RegisterCustomOpsLibrary_V2. + ///< The custom operator configurations are optional. If provided, custom operator configs are set via + ///< OrtApi::AddSessionConfigEntry. 
+ SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {}); + + SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction + + ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_VitisAI + SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options = {}); +}; +} // namespace detail + +using UnownedSessionOptions = detail::SessionOptionsImpl>; +using ConstSessionOptions = detail::ConstSessionOptionsImpl>; + +/** \brief Wrapper around ::OrtSessionOptions + * + */ +struct SessionOptions : detail::SessionOptionsImpl { + explicit SessionOptions(std::nullptr_t) {} ///< Create an empty SessionOptions object, must be assigned a valid one to be used + SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions + explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl{p} {} ///< Used for interop with the C API + UnownedSessionOptions GetUnowned() const { return UnownedSessionOptions{this->p_}; } + ConstSessionOptions GetConst() const { return ConstSessionOptions{this->p_}; } }; /** \brief Wrapper around ::OrtModelMetadata -* -*/ -struct ModelMetadata : Base { + * + */ +struct ModelMetadata : detail::Base { explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used explicit ModelMetadata(OrtModelMetadata* p) : Base{p} {} ///< Used for interop with the C API - char* GetProducerName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName - char* GetGraphName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName - char* GetDomain(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain - char* GetDescription(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription - char* GetGraphDescription(OrtAllocator* allocator) const; ///< Wraps 
OrtApi::ModelMetadataGetGraphDescription - char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys - char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap - int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion + /** \brief Returns a copy of the producer name. + * + * \param allocator to allocate memory for the copy of the name returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName + + /** \brief Returns a copy of the graph name. + * + * \param allocator to allocate memory for the copy of the name returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName + + /** \brief Returns a copy of the domain name. + * + * \param allocator to allocate memory for the copy of the name returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. + */ + AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain + + /** \brief Returns a copy of the description. + * + * \param allocator to allocate memory for the copy of the string returned + * \return a instance of smart pointer that would deallocate the buffer when out of scope. + * The OrtAllocator instances must be valid at the point of memory release. 
+ */ + AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription + + /** \brief Returns a copy of the graph description. + * + * \param allocator to allocate memory for the copy of the string returned + * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope. + * The OrtAllocator instance must be valid at the point of memory release. + */ + AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription + + /** \brief Returns a vector of copies of the custom metadata keys. + * + * \param allocator to allocate memory for the copy of the string returned + * \return a std::vector of smart pointers that deallocate the buffers when they go out of scope. + * The OrtAllocator instance must be valid at the point of memory release. + */ + std::vector GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys + + /** \brief Looks up a value by a key in the Custom Metadata map + * + * \param key zero terminated string key to look up + * \param allocator to allocate memory for the copy of the string returned + * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope. + * May be nullptr if key is not found. + * + * The OrtAllocator instance must be valid at the point of memory release.
+ */ + AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap + + int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion }; -/** \brief Wrapper around ::OrtSession -* -*/ -struct Session : Base { - explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used - Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession - Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer - Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray - - /** \brief Run the model returning results in an Ort allocated vector. - * - * Wraps OrtApi::Run - * - * The caller provides a list of inputs and a list of the desired outputs to return. - * - * See the output logs for more information on warnings/errors that occur while processing the model. - * Common errors are.. (TODO) - * - * \param[in] run_options - * \param[in] input_names Array of null terminated strings of length input_count that is the list of input names - * \param[in] input_values Array of Value objects of length input_count that is the list of input values - * \param[in] input_count Number of inputs (the size of the input_names & input_values arrays) - * \param[in] output_names Array of C style strings of length output_count that is the list of output names - * \param[in] output_count Number of outputs (the size of the output_names array) - * \return A std::vector of Value objects that directly maps to the output_count (eg. 
output_name[0] is the first entry of the returned vector) - */ - std::vector Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, - const char* const* output_names, size_t output_count); +struct IoBinding; - /** \brief Run the model returning results in user provided outputs - * Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t) - */ - void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, - const char* const* output_names, Value* output_values, size_t output_count); +namespace detail { - void Run(const RunOptions& run_options, const struct IoBinding&); ///< Wraps OrtApi::RunWithBinding +// we separate const-only methods because passing const ptr to non-const methods +// is only discovered when inline methods are compiled which is counter-intuitive +template +struct ConstSessionImpl : Base { + using B = Base; + using B::B; size_t GetInputCount() const; ///< Returns the number of model inputs size_t GetOutputCount() const; ///< Returns the number of model outputs size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden - char* GetInputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetInputName - char* GetOutputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOutputName - char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName - char* EndProfiling(OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionEndProfiling - uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs - ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata + /** \brief Returns a copy of input name at the specified index. 
+ * + * \param index must be less than the value returned by GetInputCount() + * \param allocator to allocate memory for the copy of the name returned + * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope. + * The OrtAllocator instance must be valid at the point of memory release. + */ + AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const; + + /** \brief Returns a copy of output name at the specified index. + * + * \param index must be less than the value returned by GetOutputCount() + * \param allocator to allocate memory for the copy of the name returned + * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope. + * The OrtAllocator instance must be valid at the point of memory release. + */ + AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const; + + /** \brief Returns a copy of the overridable initializer name at the specified index. + * + * \param index must be less than the value returned by GetOverridableInitializerCount() + * \param allocator to allocate memory for the copy of the name returned + * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope. + * The OrtAllocator instance must be valid at the point of memory release.
+ */ + AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName + + uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs + ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo }; -/** \brief Wrapper around ::OrtTensorTypeAndShapeInfo -* -*/ -struct TensorTypeAndShapeInfo : Base { - explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used - explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : Base{p} {} ///< Used for interop with the C API +template +struct SessionImpl : ConstSessionImpl { + using B = ConstSessionImpl; + using B::B; - ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType - size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount + /** \brief Run the model returning results in an Ort allocated vector. + * + * Wraps OrtApi::Run + * + * The caller provides a list of inputs and a list of the desired outputs to return. + * + * See the output logs for more information on warnings/errors that occur while processing the model. + * Common errors are.. 
(TODO) + * + * \param[in] run_options + * \param[in] input_names Array of null terminated strings of length input_count that is the list of input names + * \param[in] input_values Array of Value objects of length input_count that is the list of input values + * \param[in] input_count Number of inputs (the size of the input_names & input_values arrays) + * \param[in] output_names Array of C style strings of length output_count that is the list of output names + * \param[in] output_count Number of outputs (the size of the output_names array) + * \return A std::vector of Value objects that directly maps to the output_names array (eg. output_name[0] is the first entry of the returned vector) + */ + std::vector Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, size_t output_count); - size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount - void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions - void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions + /** \brief Run the model returning results in user provided outputs + * Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t) + */ + void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, Value* output_values, size_t output_count); - std::vector GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape + void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding + + /** \brief Run the model asynchronously in a thread owned by intra op thread pool + * + * Wraps OrtApi::RunAsync + * + * \param[in] run_options + * \param[in] input_names Array of null terminated UTF8 encoded strings of the 
input names + * \param[in] input_values Array of Value objects of length input_count + * \param[in] input_count Number of elements in the input_names and inputs arrays + * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names + * \param[out] output_values Array of provided Values to be filled with outputs. + * On calling RunAsync, output_values[i] could either be initialized by a null pointer or a preallocated OrtValue*. + * Later, on invoking the callback, each output_values[i] of null will be filled with an OrtValue* allocated by onnxruntime. + * Then, an OrtValue** pointer will be cast from output_values, and passed to the callback. + * NOTE: it is the customer's duty to finally release output_values and each of its members, + * regardless of whether the member (Ort::Value) is allocated by onnxruntime or preallocated by the customer. + * \param[in] output_count Number of elements in the output_names and outputs array + * \param[in] callback Callback function on model run completion + * \param[in] user_data User data that is passed back to the callback + */ + void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data); + + /** \brief End profiling and return a copy of the profiling file name. + * + * \param allocator to allocate memory for the copy of the string returned + * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope. + * The OrtAllocator instance must be valid at the point of memory release.
+ */ + AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator); ///< Wraps OrtApi::SessionEndProfiling + + /** \brief Set DynamicOptions for EPs (Execution Providers) + * + * Wraps OrtApi::SetEpDynamicOptions + * + * Valid options can be found in `include\onnxruntime\core\session\onnxruntime_session_options_config_keys.h` + * Look for `kOrtEpDynamicOptions` + * + * \param[in] keys Array of null terminated UTF8 encoded strings of EP dynamic option keys + * \param[in] values Array of null terminated UTF8 encoded string of EP dynamic option values + * \param[in] kv_len Number of elements in the keys and values arrays + */ + void SetEpDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len); }; -/** \brief Wrapper around ::OrtSequenceTypeInfo -* -*/ -struct SequenceTypeInfo : Base { - explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used - explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : Base{p} {} ///< Used for interop with the C API - - TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType -}; +} // namespace detail -/** \brief Wrapper around ::OrtMapTypeInfo -* -*/ -struct MapTypeInfo : Base { - explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used - explicit MapTypeInfo(OrtMapTypeInfo* p) : Base{p} {} ///< Used for interop with the C API +using ConstSession = detail::ConstSessionImpl>; +using UnownedSession = detail::SessionImpl>; - ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType - TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType +/** \brief Wrapper around ::OrtSession + * + */ +struct Session : detail::SessionImpl { + explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used + Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& 
options); ///< Wraps OrtApi::CreateSession + Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer + Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray + Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer + + ConstSession GetConst() const { return ConstSession{this->p_}; } + UnownedSession GetUnowned() const { return UnownedSession{this->p_}; } }; -struct TypeInfo : Base { - explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used - explicit TypeInfo(OrtTypeInfo* p) : Base{p} {} ///< C API Interop +namespace detail { +template +struct MemoryInfoImpl : Base { + using B = Base; + using B::B; - Unowned GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo - Unowned GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo - Unowned GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo + std::string GetAllocatorName() const; + OrtAllocatorType GetAllocatorType() const; + int GetDeviceId() const; + OrtMemoryInfoDeviceType GetDeviceType() const; + OrtMemType GetMemoryType() const; - ONNXType GetONNXType() const; + template + bool operator==(const MemoryInfoImpl& o) const; }; +} // namespace detail -struct Value : Base { - // This structure is used to feed sparse tensor values - // information for use with FillSparseTensor() API - // if the data type for the sparse tensor values is numeric - // use data.p_data, otherwise, use data.str pointer to feed - // values. data.str is an array of const char* that are zero terminated. 
- // number of strings in the array must match shape size. - // For fully sparse tensors use shape {0} and set p_data/str - // to nullptr. - struct OrtSparseValuesParam { - const int64_t* values_shape; - size_t values_shape_len; - union { - const void* p_data; - const char** str; - } data; - }; +// Const object holder that does not own the underlying object +using ConstMemoryInfo = detail::MemoryInfoImpl>; - // Provides a way to pass shape in a single - // argument - struct Shape { - const int64_t* shape; - size_t shape_len; - }; +/** \brief Wrapper around ::OrtMemoryInfo + * + */ +struct MemoryInfo : detail::MemoryInfoImpl { + static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); + explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created + explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C Api + MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); + ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } +}; - /// \brief Wraps OrtApi::CreateTensorWithDataAsOrtValue - template - static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len); - /// \brief Wraps OrtApi::CreateTensorWithDataAsOrtValue - static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, - ONNXTensorElementDataType type); +namespace detail { +template +struct TensorTypeAndShapeInfoImpl : Base { + using B = Base; + using B::B; + + ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType + size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount + + size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount + + /** \deprecated use GetShape() returning std::vector + * [[deprecated]] + * This interface is unsafe to use + */ + [[deprecated("use GetShape()")]] void 
GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions + + void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions + + std::vector GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape +}; + +} // namespace detail + +using ConstTensorTypeAndShapeInfo = detail::TensorTypeAndShapeInfoImpl>; + +/** \brief Wrapper around ::OrtTensorTypeAndShapeInfo + * + */ +struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl { + explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used + explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API + ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; } +}; + +namespace detail { +template +struct SequenceTypeInfoImpl : Base { + using B = Base; + using B::B; + TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType +}; + +} // namespace detail + +using ConstSequenceTypeInfo = detail::SequenceTypeInfoImpl>; + +/** \brief Wrapper around ::OrtSequenceTypeInfo + * + */ +struct SequenceTypeInfo : detail::SequenceTypeInfoImpl { + explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used + explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl{p} {} ///< Used for interop with the C API + ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; } +}; + +namespace detail { +template +struct OptionalTypeInfoImpl : Base { + using B = Base; + using B::B; + TypeInfo GetOptionalElementType() const; ///< Wraps OrtApi::CastOptionalTypeToContainedTypeInfo +}; + +} // namespace detail + +// This is always owned by the TypeInfo and can only be obtained from it. 
+using ConstOptionalTypeInfo = detail::OptionalTypeInfoImpl>; + +namespace detail { +template +struct MapTypeInfoImpl : detail::Base { + using B = Base; + using B::B; + ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType + TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType +}; + +} // namespace detail + +using ConstMapTypeInfo = detail::MapTypeInfoImpl>; + +/** \brief Wrapper around ::OrtMapTypeInfo + * + */ +struct MapTypeInfo : detail::MapTypeInfoImpl { + explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used + explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl{p} {} ///< Used for interop with the C API + ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; } +}; + +namespace detail { +template +struct TypeInfoImpl : detail::Base { + using B = Base; + using B::B; + + ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo + ConstSequenceTypeInfo GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo + ConstMapTypeInfo GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo + ConstOptionalTypeInfo GetOptionalTypeInfo() const; ///< wraps OrtApi::CastTypeInfoToOptionalTypeInfo + + ONNXType GetONNXType() const; +}; +} // namespace detail + +/// +/// Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value. +/// Provides access to const OrtTypeInfo APIs. +/// +using ConstTypeInfo = detail::TypeInfoImpl>; + +/// +/// Type information that may contain either TensorTypeAndShapeInfo or +/// the information about contained sequence or map depending on the ONNXType. 
+/// +struct TypeInfo : detail::TypeInfoImpl { + explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used + explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl{p} {} ///< C API Interop + + ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; } +}; + +namespace detail { +// This structure is used to feed sparse tensor values +// information for use with FillSparseTensor() API +// if the data type for the sparse tensor values is numeric +// use data.p_data, otherwise, use data.str pointer to feed +// values. data.str is an array of const char* that are zero terminated. +// number of strings in the array must match shape size. +// For fully sparse tensors use shape {0} and set p_data/str +// to nullptr. +struct OrtSparseValuesParam { + const int64_t* values_shape; + size_t values_shape_len; + union { + const void* p_data; + const char** str; + } data; +}; + +// Provides a way to pass shape in a single +// argument +struct Shape { + const int64_t* shape; + size_t shape_len; +}; + +template +struct ConstValueImpl : Base { + using B = Base; + using B::B; + + /// + /// Obtains a pointer to a user defined data for experimental purposes + /// + template + void GetOpaqueData(const char* domain, const char* type_name, R&) const; ///< Wraps OrtApi::GetOpaqueValue + + bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc + bool HasValue() const; /// < Return true if OrtValue contains data and returns false if the OrtValue is a None + + size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements + Value GetValue(int index, OrtAllocator* allocator) const; + + /// + /// This API returns a full length of string data contained within either a tensor or a sparse Tensor. + /// For sparse tensor it returns a full length of stored non-empty strings (values). 
The API is useful + /// for allocating necessary memory and calling GetStringTensorContent(). + /// + /// total length of UTF-8 encoded bytes contained. No zero terminators counted. + size_t GetStringTensorDataLength() const; + + /// + /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor + /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate. + /// The user must also allocate offsets buffer with the number of entries equal to that of the contained + /// strings. + /// + /// Strings are always assumed to be on CPU, no X-device copy. + /// + /// user allocated buffer + /// length in bytes of the allocated buffer + /// a pointer to the offsets user allocated buffer + /// count of offsets, must be equal to the number of strings contained. + /// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo() + /// for sparse tensors + void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const; + + /// + /// Returns a const typed pointer to the tensor contained data. + /// No type checking is performed, the caller must ensure the type matches the tensor type. + /// + /// + /// const pointer to data, no copies made + template + const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData /// + + /// + /// Returns a non-typed pointer to a tensor contained data. + /// + /// const pointer to data, no copies made + const void* GetTensorRawData() const; + + /// + /// The API returns type information for data contained in a tensor. For sparse + /// tensors it returns type information for contained non-zero values. + /// It returns dense shape for sparse tensors. + /// + /// TypeInfo + TypeInfo GetTypeInfo() const; + + /// + /// The API returns type information for data contained in a tensor. For sparse + /// tensors it returns type information for contained non-zero values. 
+ /// It returns dense shape for sparse tensors. + /// + /// TensorTypeAndShapeInfo + TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; + + /// + /// This API returns information about the memory allocation used to hold data. + /// + /// Non owning instance of MemoryInfo + ConstMemoryInfo GetTensorMemoryInfo() const; + + /// + /// The API copies UTF-8 encoded bytes for the requested string element + /// contained within a tensor or a sparse tensor into a provided buffer. + /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate. + /// + /// + /// + /// + void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const; + + /// + /// Returns string tensor UTF-8 encoded string element. + /// Use of this API is recommended over GetStringTensorElement() that takes void* buffer pointer. + /// + /// + /// std::string + std::string GetStringTensorElement(size_t element_index) const; + + /// + /// The API returns a byte length of UTF-8 encoded string element + /// contained in either a tensor or a spare tensor values. + /// + /// + /// byte length for the specified string element + size_t GetStringTensorElementLength(size_t element_index) const; #if !defined(DISABLE_SPARSE_TENSORS) /// - /// This is a simple forwarding method to the other overload that helps deducing - /// data type enum value from the type of the buffer. + /// The API returns the sparse data format this OrtValue holds in a sparse tensor. + /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used + /// the value returned is ORT_SPARSE_UNDEFINED. + /// + /// Format enum + OrtSparseFormat GetSparseFormat() const; + + /// + /// The API returns type and shape information for stored non-zero values of the + /// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer. 
+ /// + /// TensorTypeAndShapeInfo values information + TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const; + + /// + /// The API returns type and shape information for the specified indices. Each supported + /// indices have their own enum values even if a give format has more than one kind of indices. + /// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer. + /// + /// enum requested + /// type and shape information + TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const; + + /// + /// The API retrieves a pointer to the internal indices buffer. The API merely performs + /// a convenience data type casting on the return type pointer. Make sure you are requesting + /// the right type, use GetSparseTensorIndicesTypeShapeInfo(); + /// + /// type to cast to + /// requested indices kind + /// number of indices entries + /// Pinter to the internal sparse tensor buffer containing indices. Do not free this pointer. + template + const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const; + + /// + /// Returns true if the OrtValue contains a sparse tensor /// - /// numeric datatype. This API is not suitable for strings. - /// Memory description where the user buffers reside (CPU vs GPU etc) - /// pointer to the user supplied buffer, use nullptr for fully sparse tensors - /// a would be dense shape of the tensor - /// non zero values shape. Use a single 0 shape for fully sparse tensors. /// - template - static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape, - const Shape& values_shape); + bool IsSparseTensor() const; /// - /// Creates an OrtValue instance containing SparseTensor. This constructs - /// a sparse tensor that makes use of user allocated buffers. It does not make copies - /// of the user provided data and does not modify it. 
The lifespan of user provided buffers should - /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contain - /// a pointer to non-zero values. To fully populate the sparse tensor call UseIndices() API below - /// to supply a sparse format specific indices. - /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings - /// can be properly copied into the allocated buffer. + /// The API returns a pointer to an internal buffer of the sparse tensor + /// containing non-zero values. The API merely does casting. Make sure you + /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo() + /// first. /// - /// Memory description where the user buffers reside (CPU vs GPU etc) - /// pointer to the user supplied buffer, use nullptr for fully sparse tensors - /// a would be dense shape of the tensor - /// non zero values shape. Use a single 0 shape for fully sparse tensors. - /// data type - /// Ort::Value instance containing SparseTensor - static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape, - const Shape& values_shape, ONNXTensorElementDataType type); + /// numeric data types only. Use GetStringTensor*() to retrieve strings. + /// a pointer to the internal values buffer. Do not free this pointer. + template + const R* GetSparseTensorValues() const; + +#endif +}; + +template +struct ValueImpl : ConstValueImpl { + using B = ConstValueImpl; + using B::B; + + /// + /// Returns a non-const typed pointer to an OrtValue/Tensor contained buffer + /// No type checking is performed, the caller must ensure the type matches the tensor type. + /// + /// non-const pointer to data, no copies made + template + R* GetTensorMutableData(); + + /// + /// Returns a non-typed non-const pointer to a tensor contained data. 
+ /// + /// pointer to data, no copies made + void* GetTensorMutableRawData(); + + /// + // Obtain a reference to an element of data at the location specified + /// by the vector of dims. + /// + /// + /// [in] expressed by a vecotr of dimensions offsets + /// + template + R& At(const std::vector& location); + + /// + /// Set all strings at once in a string tensor + /// + /// [in] An array of strings. Each string in this array must be null terminated. + /// [in] Count of strings in s (Must match the size of \p value's tensor shape) + void FillStringTensor(const char* const* s, size_t s_len); + + /// + /// Set a single string in a string tensor + /// + /// [in] A null terminated UTF-8 encoded string + /// [in] Index of the string in the tensor to set + void FillStringTensorElement(const char* s, size_t index); + + /// + /// Allocate if necessary and obtain a pointer to a UTF-8 + /// encoded string element buffer indexed by the flat element index, + /// of the specified length. + /// + /// This API is for advanced usage. It avoids a need to construct + /// an auxiliary array of string pointers, and allows to write data directly + /// (do not zero terminate). + /// + /// + /// + /// a pointer to a writable buffer + char* GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length); +#if !defined(DISABLE_SPARSE_TENSORS) /// /// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor. /// Values are supplied with a CreateSparseTensor() API. 
The supplied indices are not copied and the user @@ -574,40 +1606,6 @@ struct Value : Base { /// user allocated buffer with indices or nullptr for fully spare tensors void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data); -#endif // !defined(DISABLE_SPARSE_TENSORS) - - // \brief Wraps OrtApi::CreateTensorAsOrtValue - template - static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len); - // \brief Wraps OrtApi::CreateTensorAsOrtValue - static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); - -#if !defined(DISABLE_SPARSE_TENSORS) - /// - /// This is a simple forwarding method the below CreateSparseTensor. - /// This helps to specify data type enum in terms of C++ data type. - /// Use CreateSparseTensor - /// - /// numeric data type only. String data enum must be specified explicitly. - /// allocator to use - /// a would be dense shape of the tensor - /// Ort::Value - template - static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape); - - /// - /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data. - /// The data must be supplied by on of the FillSparseTensor() methods that take both non-zero values - /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator. - /// Use this API to create OrtValues that contain sparse tensors with all supported data types including - /// strings. - /// - /// allocator to use. 
The allocator lifespan must eclipse that of the resulting OrtValue - /// a would be dense shape of the tensor - /// data type - /// an instance of Ort::Value - static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type); - /// /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API /// and copy the values and COO indices into it. If data_mem_info specifies that the data is located @@ -641,166 +1639,187 @@ struct Value : Base { /// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located /// at difference device than the allocator, a X-device copy will be performed if possible. /// - /// specified buffer memory description - /// values buffer information - /// indices shape. use {0} for fully sparse tensors - /// pointer to indices data or nullptr for fully sparse tensors - void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info, - const OrtSparseValuesParam& values, - const Shape& indices_shape, - const int32_t* indices_data); - - /// - /// The API returns the sparse data format this OrtValue holds in a sparse tensor. - /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used - /// the value returned is ORT_SPARSE_UNDEFINED. - /// - /// Format enum - OrtSparseFormat GetSparseFormat() const; - - /// - /// The API returns type and shape information for stored non-zero values of the - /// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer. - /// - /// TensorTypeAndShapeInfo values information - TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const; - - /// - /// The API returns type and shape information for the specified indices. Each supported - /// indices have their own enum values even if a give format has more than one kind of indices. - /// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer. 
- /// - /// enum requested - /// type and shape information - TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const; - - /// - /// The API retrieves a pointer to the internal indices buffer. The API merely performs - /// a convenience data type casting on the return type pointer. Make sure you are requesting - /// the right type, use GetSparseTensorIndicesTypeShapeInfo(); - /// - /// type to cast to - /// requested indices kind - /// number of indices entries - /// Pinter to the internal sparse tensor buffer containing indices. Do not free this pointer. - template - const T* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const; + /// specified buffer memory description + /// values buffer information + /// indices shape. use {0} for fully sparse tensors + /// pointer to indices data or nullptr for fully sparse tensors + void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info, + const OrtSparseValuesParam& values, + const Shape& indices_shape, + const int32_t* indices_data); -#endif // !defined(DISABLE_SPARSE_TENSORS) +#endif +}; - static Value CreateMap(Value& keys, Value& values); ///< Wraps OrtApi::CreateValue - static Value CreateSequence(std::vector& values); ///< Wraps OrtApi::CreateValue +} // namespace detail - template - static Value CreateOpaque(const char* domain, const char* type_name, const T&); ///< Wraps OrtApi::CreateOpaqueValue +using ConstValue = detail::ConstValueImpl>; +using UnownedValue = detail::ValueImpl>; - template - void GetOpaqueData(const char* domain, const char* type_name, T&) const; ///< Wraps OrtApi::GetOpaqueValue +/** \brief Wrapper around ::OrtValue + * + */ +struct Value : detail::ValueImpl { + using Base = detail::ValueImpl; + using OrtSparseValuesParam = detail::OrtSparseValuesParam; + using Shape = detail::Shape; - explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used - explicit 
Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API + explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used + explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API Value(Value&&) = default; Value& operator=(Value&&) = default; - bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc - bool HasValue() const; /// < Return true if OrtValue contains data and returns false if the OrtValue is a None - -#if !defined(DISABLE_SPARSE_TENSORS) - /// - /// Returns true if the OrtValue contains a sparse tensor - /// - /// - bool IsSparseTensor() const; -#endif - - size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements - Value GetValue(int index, OrtAllocator* allocator) const; - - /// - /// This API returns a full length of string data contained within either a tensor or a sparse Tensor. - /// For sparse tensor it returns a full length of stored non-empty strings (values). The API is useful - /// for allocating necessary memory and calling GetStringTensorContent(). - /// - /// total length of UTF-8 encoded bytes contained. No zero terminators counted. - size_t GetStringTensorDataLength() const; + ConstValue GetConst() const { return ConstValue{this->p_}; } + UnownedValue GetUnowned() const { return UnownedValue{this->p_}; } + + /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue. + * \tparam T The numeric datatype. This API is not suitable for strings. + * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc). + * \param p_data Pointer to the data buffer. + * \param p_data_element_count The number of elements in the data buffer. + * \param shape Pointer to the tensor shape dimensions. + * \param shape_len The number of tensor shape dimensions. 
+ */ + template + static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len); - /// - /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor - /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate. - /// The user must also allocate offsets buffer with the number of entries equal to that of the contained - /// strings. - /// - /// Strings are always assumed to be on CPU, no X-device copy. - /// - /// user allocated buffer - /// length in bytes of the allocated buffer - /// a pointer to the offsets user allocated buffer - /// count of offsets, must be equal to the number of strings contained. - /// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo() - /// for sparse tensors - void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const; + /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue. + * + * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc). + * \param p_data Pointer to the data buffer. + * \param p_data_byte_count The number of bytes in the data buffer. + * \param shape Pointer to the tensor shape dimensions. + * \param shape_len The number of tensor shape dimensions. + * \param type The data type. + */ + static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type); + /** \brief Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue. + * This overload will allocate the buffer for the tensor according to the supplied shape and data type. + * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released. 
+ * The input data would need to be copied into the allocated buffer. + * This API is not suitable for strings. + * + * \tparam T The numeric datatype. This API is not suitable for strings. + * \param allocator The allocator to use. + * \param shape Pointer to the tensor shape dimensions. + * \param shape_len The number of tensor shape dimensions. + */ template - T* GetTensorMutableData(); ///< Wraps OrtApi::GetTensorMutableData + static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len); + + /** \brief Creates an OrtValue with a tensor using the supplied OrtAllocator. + * Wraps OrtApi::CreateTensorAsOrtValue. + * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released. + * The input data would need to be copied into the allocated buffer. + * This API is not suitable for strings. + * + * \param allocator The allocator to use. + * \param shape Pointer to the tensor shape dimensions. + * \param shape_len The number of tensor shape dimensions. + * \param type The data type. + */ + static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); + /** \brief Creates an OrtValue with a Map Onnx type representation. + * The API would ref-count the supplied OrtValues and they will be released + * when the returned OrtValue is released. The caller may release keys and values after the call + * returns. + * + * \param keys an OrtValue containing a tensor with primitive data type keys. + * \param values an OrtValue that may contain a tensor. Ort currently supports only primitive data type values. + */ + static Value CreateMap(const Value& keys, const Value& values); ///< Wraps OrtApi::CreateValue + + /** \brief Creates an OrtValue with a Sequence Onnx type representation. + * The API would ref-count the supplied OrtValues and they will be released + * when the returned OrtValue is released. 
The caller may release the values after the call + * returns. + * + * \param values a vector of OrtValues that must have the same Onnx value type. + */ + static Value CreateSequence(const std::vector& values); ///< Wraps OrtApi::CreateValue + + /** \brief Creates an OrtValue wrapping an Opaque type. + * This is used for experimental support of non-tensor types. + * + * \tparam T - the type of the value. + * \param domain - zero terminated utf-8 string. Domain of the type. + * \param type_name - zero terminated utf-8 string. Name of the type. + * \param value - the value to be wrapped. + */ template - const T* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData + static Value CreateOpaque(const char* domain, const char* type_name, const T& value); ///< Wraps OrtApi::CreateOpaqueValue #if !defined(DISABLE_SPARSE_TENSORS) /// - /// The API returns a pointer to an internal buffer of the sparse tensor - /// containing non-zero values. The API merely does casting. Make sure you - /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo() - /// first. + /// This is a simple forwarding method to the other overload that helps deducing + /// data type enum value from the type of the buffer. /// - /// numeric data types only. Use GetStringTensor*() to retrieve strings. - /// a pointer to the internal values buffer. Do not free this pointer. - template - const T* GetSparseTensorValues() const; -#endif - + /// numeric datatype. This API is not suitable for strings. + /// Memory description where the user buffers reside (CPU vs GPU etc) + /// pointer to the user supplied buffer, use nullptr for fully sparse tensors + /// a would be dense shape of the tensor + /// non zero values shape. Use a single 0 shape for fully sparse tensors. + /// template - T& At(const std::vector& location); - - /// - /// The API returns type information for data contained in a tensor. 
For sparse - /// tensors it returns type information for contained non-zero values. - /// It returns dense shape for sparse tensors. - /// - /// TypeInfo - TypeInfo GetTypeInfo() const; + static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape, + const Shape& values_shape); /// - /// The API returns type information for data contained in a tensor. For sparse - /// tensors it returns type information for contained non-zero values. - /// It returns dense shape for sparse tensors. + /// Creates an OrtValue instance containing SparseTensor. This constructs + /// a sparse tensor that makes use of user allocated buffers. It does not make copies + /// of the user provided data and does not modify it. The lifespan of user provided buffers should + /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contain + /// a pointer to non-zero values. To fully populate the sparse tensor call UseIndices() API below + /// to supply a sparse format specific indices. + /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings + /// can be properly copied into the allocated buffer. /// - /// TensorTypeAndShapeInfo - TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; + /// Memory description where the user buffers reside (CPU vs GPU etc) + /// pointer to the user supplied buffer, use nullptr for fully sparse tensors + /// a would be dense shape of the tensor + /// non zero values shape. Use a single 0 shape for fully sparse tensors. + /// data type + /// Ort::Value instance containing SparseTensor + static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape, + const Shape& values_shape, ONNXTensorElementDataType type); /// - /// The API returns a byte length of UTF-8 encoded string element - /// contained in either a tensor or a spare tensor values. 
+ /// This is a simple forwarding method to the below CreateSparseTensor. + /// This helps to specify data type enum in terms of C++ data type. + /// Use CreateSparseTensor /// - /// - /// byte length for the specified string element - size_t GetStringTensorElementLength(size_t element_index) const; + /// numeric data type only. String data enum must be specified explicitly. + /// allocator to use + /// a would be dense shape of the tensor + /// Ort::Value + template + static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape); /// - /// The API copies UTF-8 encoded bytes for the requested string element - /// contained within a tensor or a sparse tensor into a provided buffer. - /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate. + /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data. + /// The data must be supplied by on of the FillSparseTensor() methods that take both non-zero values + /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator. + /// Use this API to create OrtValues that contain sparse tensors with all supported data types including + /// strings. /// - /// - /// - /// - void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const; + /// allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue + /// a would be dense shape of the tensor + /// data type + /// an instance of Ort::Value + static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type); - void FillStringTensor(const char* const* s, size_t s_len); - void FillStringTensorElement(const char* s, size_t index); +#endif // !defined(DISABLE_SPARSE_TENSORS) }; -// Represents native memory allocation +/// +/// Represents native memory allocation coming from one of the +/// OrtAllocators registered with OnnxRuntime. 
+/// Use it to wrap an allocation made by an allocator +/// so it can be automatically released when no longer needed. +/// struct MemoryAllocation { MemoryAllocation(OrtAllocator* allocator, void* p, size_t size); ~MemoryAllocation(); @@ -818,81 +1837,99 @@ struct MemoryAllocation { size_t size_; }; -struct AllocatorWithDefaultOptions { - AllocatorWithDefaultOptions(); - - operator OrtAllocator*() { return p_; } - operator const OrtAllocator*() const { return p_; } +namespace detail { +template +struct AllocatorImpl : Base { + using B = Base; + using B::B; void* Alloc(size_t size); - // The return value will own the allocation MemoryAllocation GetAllocation(size_t size); void Free(void* p); + ConstMemoryInfo GetInfo() const; +}; - const OrtMemoryInfo* GetInfo() const; +} // namespace detail - private: - OrtAllocator* p_{}; +/** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime + * + */ +struct AllocatorWithDefaultOptions : detail::AllocatorImpl> { + explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance + AllocatorWithDefaultOptions(); }; -struct MemoryInfo : Base { - static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); - - explicit MemoryInfo(std::nullptr_t) {} - explicit MemoryInfo(OrtMemoryInfo* p) : Base{p} {} ///< Used for interop with the C API - MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); +/** \brief Wrapper around ::OrtAllocator + * + */ +struct Allocator : detail::AllocatorImpl { + explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance + Allocator(const Session& session, const OrtMemoryInfo*); +}; - std::string GetAllocatorName() const; - OrtAllocatorType GetAllocatorType() const; - int GetDeviceId() const; - OrtMemType GetMemoryType() const; +using UnownedAllocator = detail::AllocatorImpl>; - bool operator==(const MemoryInfo& o) 
const; -}; +namespace detail { +namespace binding_utils { +// Bring these out of template +std::vector GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*); +std::vector GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*); +} // namespace binding_utils -struct Allocator : public Base { - Allocator(const Session& session, const MemoryInfo&); +template +struct ConstIoBindingImpl : Base { + using B = Base; + using B::B; - void* Alloc(size_t size) const; - // The return value will own the allocation - MemoryAllocation GetAllocation(size_t size); - void Free(void* p) const; - Unowned GetInfo() const; + std::vector GetOutputNames() const; + std::vector GetOutputNames(OrtAllocator*) const; + std::vector GetOutputValues() const; + std::vector GetOutputValues(OrtAllocator*) const; }; -struct IoBinding : public Base { - explicit IoBinding(Session& session); +template +struct IoBindingImpl : ConstIoBindingImpl { + using B = ConstIoBindingImpl; + using B::B; + void BindInput(const char* name, const Value&); void BindOutput(const char* name, const Value&); - void BindOutput(const char* name, const MemoryInfo&); - std::vector GetOutputNames() const; - std::vector GetOutputNames(Allocator&) const; - std::vector GetOutputValues() const; - std::vector GetOutputValues(Allocator&) const; + void BindOutput(const char* name, const OrtMemoryInfo*); void ClearBoundInputs(); void ClearBoundOutputs(); void SynchronizeInputs(); void SynchronizeOutputs(); +}; - private: - std::vector GetOutputNamesHelper(OrtAllocator*) const; - std::vector GetOutputValuesHelper(OrtAllocator*) const; +} // namespace detail + +using ConstIoBinding = detail::ConstIoBindingImpl>; +using UnownedIoBinding = detail::IoBindingImpl>; + +/** \brief Wrapper around ::OrtIoBinding + * + */ +struct IoBinding : detail::IoBindingImpl { + explicit IoBinding(std::nullptr_t) {} ///< Create an empty object for convenience. Sometimes, we want to initialize members later. 
+ explicit IoBinding(Session& session); + ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; } + UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; } }; /*! \struct Ort::ArenaCfg - * \brief it is a structure that represents the configuration of an arena based allocator - * \details Please see docs/C_API.md for details - */ -struct ArenaCfg : Base { + * \brief it is a structure that represents the configuration of an arena based allocator + * \details Please see docs/C_API.md for details + */ +struct ArenaCfg : detail::Base { explicit ArenaCfg(std::nullptr_t) {} ///< Create an empty ArenaCfg object, must be assigned a valid one to be used /** - * Wraps OrtApi::CreateArenaCfg - * \param max_mem - use 0 to allow ORT to choose the default - * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested - * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default - * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default - * See docs/C_API.md for details on what the following parameters mean and how to choose these values - */ + * Wraps OrtApi::CreateArenaCfg + * \param max_mem - use 0 to allow ORT to choose the default + * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested + * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default + * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default + * See docs/C_API.md for details on what the following parameters mean and how to choose these values + */ ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk); }; @@ -900,60 +1937,426 @@ struct ArenaCfg : Base { // Custom OPs (only needed to implement custom OPs) // -struct CustomOpApi { - CustomOpApi(const OrtApi& api) : api_(api) {} +/// +/// This struct provides life time management for 
custom op attribute +/// +struct OpAttr : detail::Base { + OpAttr(const char* name, const void* data, int len, OrtOpAttrType type); +}; + +/** + * Macro that logs a message using the provided logger. Throws an exception if OrtApi::Logger_LogMessage fails. + * Example: ORT_CXX_LOG(logger, ORT_LOGGING_LEVEL_INFO, "Log a message"); + * + * \param logger The Ort::Logger instance to use. Must be a value or reference. + * \param message_severity The logging severity level of the message. + * \param message A null-terminated UTF-8 message to log. + */ +#define ORT_CXX_LOG(logger, message_severity, message) \ + do { \ + if (message_severity >= logger.GetLoggingSeverityLevel()) { \ + Ort::ThrowOnError(logger.LogMessage(message_severity, ORT_FILE, __LINE__, \ + static_cast(__FUNCTION__), message)); \ + } \ + } while (false) - template // T is only implemented for std::vector, std::vector, float, int64_t, and string - T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name); +/** + * Macro that logs a message using the provided logger. Can be used in noexcept code since errors are silently ignored. + * Example: ORT_CXX_LOG_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log a message"); + * + * \param logger The Ort::Logger instance to use. Must be a value or reference. + * \param message_severity The logging severity level of the message. + * \param message A null-terminated UTF-8 message to log. 
+ */ +#define ORT_CXX_LOG_NOEXCEPT(logger, message_severity, message) \ + do { \ + if (message_severity >= logger.GetLoggingSeverityLevel()) { \ + static_cast(logger.LogMessage(message_severity, ORT_FILE, __LINE__, \ + static_cast(__FUNCTION__), message)); \ + } \ + } while (false) - OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value); - size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info); - ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info); - size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info); - void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length); - void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count); +/** + * Macro that logs a printf-like formatted message using the provided logger. Throws an exception if + * OrtApi::Logger_LogMessage fails or if a formatting error occurs. + * Example: ORT_CXX_LOGF(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12); + * + * \param logger The Ort::Logger instance to use. Must be a value or reference. + * \param message_severity The logging severity level of the message. + * \param format A null-terminated UTF-8 format string forwarded to a printf-like function. + * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats. + * \param ... Zero or more variadic arguments referenced by the format string. + */ +#define ORT_CXX_LOGF(logger, message_severity, /*format,*/...) 
\ + do { \ + if (message_severity >= logger.GetLoggingSeverityLevel()) { \ + Ort::ThrowOnError(logger.LogFormattedMessage(message_severity, ORT_FILE, __LINE__, \ + static_cast(__FUNCTION__), __VA_ARGS__)); \ + } \ + } while (false) - template - T* GetTensorMutableData(_Inout_ OrtValue* value); - template - const T* GetTensorData(_Inout_ const OrtValue* value); +/** + * Macro that logs a printf-like formatted message using the provided logger. Can be used in noexcept code since errors + * are silently ignored. + * Example: ORT_CXX_LOGF_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12); + * + * \param logger The Ort::Logger instance to use. Must be a value or reference. + * \param message_severity The logging severity level of the message. + * \param format A null-terminated UTF-8 format string forwarded to a printf-like function. + * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats. + * \param ... Zero or more variadic arguments referenced by the format string. + */ +#define ORT_CXX_LOGF_NOEXCEPT(logger, message_severity, /*format,*/...) \ + do { \ + if (message_severity >= logger.GetLoggingSeverityLevel()) { \ + static_cast(logger.LogFormattedMessage(message_severity, ORT_FILE, __LINE__, \ + static_cast(__FUNCTION__), __VA_ARGS__)); \ + } \ + } while (false) + +/// +/// This class represents an ONNX Runtime logger that can be used to log information with an +/// associated severity level and source code location (file path, line number, function name). +/// +/// A Logger can be obtained from within custom operators by calling Ort::KernelInfo::GetLogger(). +/// Instances of Ort::Logger are the size of two pointers and can be passed by value. +/// +/// Use the ORT_CXX_LOG macros to ensure the source code location is set properly from the callsite +/// and to take advantage of a cached logging severity level that can bypass calls to the underlying C API. +/// +struct Logger { + /** + * Creates an empty Ort::Logger. 
Must be initialized from a valid Ort::Logger before use. + */ + Logger() = default; + + /** + * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use. + */ + explicit Logger(std::nullptr_t) {} + + /** + * Creates a logger from an ::OrtLogger instance. Caches the logger's current severity level by calling + * OrtApi::Logger_GetLoggingSeverityLevel. Throws an exception if OrtApi::Logger_GetLoggingSeverityLevel fails. + * + * \param logger The ::OrtLogger to wrap. + */ + explicit Logger(const OrtLogger* logger); + + ~Logger() = default; + + Logger(const Logger&) = default; + Logger& operator=(const Logger&) = default; + + Logger(Logger&& v) noexcept = default; + Logger& operator=(Logger&& v) noexcept = default; + + /** + * Returns the logger's current severity level from the cached member. + * + * \return The current ::OrtLoggingLevel. + */ + OrtLoggingLevel GetLoggingSeverityLevel() const noexcept; + + /** + * Logs the provided message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOG or ORT_CXX_LOG_NOEXCEPT + * macros to properly set the source code location and to use the cached severity level to potentially bypass + * calls to the underlying C API. + * + * \param log_severity_level The message's logging severity level. + * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE. + * \param line_number The file line number in which the message is logged. Usually the value of __LINE__. + * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__. + * \param message The message to log. + * \return A Ort::Status value to indicate error or success. + */ + Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number, + const char* func_name, const char* message) const noexcept; + + /** + * Logs a printf-like formatted message via OrtApi::Logger_LogMessage. 
Use the ORT_CXX_LOGF or ORT_CXX_LOGF_NOEXCEPT + * macros to properly set the source code location and to use the cached severity level to potentially bypass + * calls to the underlying C API. Returns an error status if a formatting error occurs. + * + * \param log_severity_level The message's logging severity level. + * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE. + * \param line_number The file line number in which the message is logged. Usually the value of __LINE__. + * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__. + * \param format A null-terminated UTF-8 format string forwarded to a printf-like function. + * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats. + * \param args Zero or more variadic arguments referenced by the format string. + * \return A Ort::Status value to indicate error or success. + */ + template + Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number, + const char* func_name, const char* format, Args&&... args) const noexcept; + + private: + const OrtLogger* logger_{}; + OrtLoggingLevel cached_severity_level_{}; +}; - const OrtMemoryInfo* GetTensorMemoryInfo(_In_ const OrtValue* value); +/// +/// This class wraps a raw pointer OrtKernelContext* that is being passed +/// to the custom kernel Compute() method. Use it to safely access context +/// attributes, input and output parameters with exception safety guarantees. +/// See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc +/// +struct KernelContext { + explicit KernelContext(OrtKernelContext* context); + size_t GetInputCount() const; + size_t GetOutputCount() const; + // If input is optional and is not present, the method returns en empty ConstValue + // which can be compared to nullptr. 
+ ConstValue GetInput(size_t index) const; + // If outout is optional and is not present, the method returns en empty UnownedValue + // which can be compared to nullptr. + UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const; + UnownedValue GetOutput(size_t index, const std::vector& dims) const; + void* GetGPUComputeStream() const; + Logger GetLogger() const; + OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const; + OrtKernelContext* GetOrtKernelContext() const { return ctx_; } + void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const; - std::vector GetTensorShape(const OrtTensorTypeAndShapeInfo* info); - void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input); - size_t KernelContext_GetInputCount(const OrtKernelContext* context); - const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index); - size_t KernelContext_GetOutputCount(const OrtKernelContext* context); - OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count); - void* KernelContext_GetGPUComputeStream(const OrtKernelContext* context); + private: + OrtKernelContext* ctx_; +}; + +struct KernelInfo; + +namespace detail { +namespace attr_utils { +void GetAttr(const OrtKernelInfo* p, const char* name, float&); +void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&); +void GetAttr(const OrtKernelInfo* p, const char* name, std::string&); +void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector&); +void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector&); +} // namespace attr_utils + +template +struct KernelInfoImpl : Base { + using B = Base; + using B::B; + + KernelInfo Copy() const; + + template // R is only implemented for float, int64_t, and string + R GetAttribute(const char* name) const { + R val; + attr_utils::GetAttr(this->p_, name, val); + return val; + } + + 
template // R is only implemented for std::vector, std::vector + std::vector GetAttributes(const char* name) const { + std::vector result; + attr_utils::GetAttrs(this->p_, name, result); + return result; + } + + Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const; + + size_t GetInputCount() const; + size_t GetOutputCount() const; + + std::string GetInputName(size_t index) const; + std::string GetOutputName(size_t index) const; + + TypeInfo GetInputTypeInfo(size_t index) const; + TypeInfo GetOutputTypeInfo(size_t index) const; + + ConstValue GetTensorConstantInput(size_t index, int* is_constant) const; + + std::string GetNodeName() const; + Logger GetLogger() const; +}; + +} // namespace detail + +using ConstKernelInfo = detail::KernelInfoImpl>; + +/// +/// This struct owns the OrtKernInfo* pointer when a copy is made. +/// For convenient wrapping of OrtKernelInfo* passed to kernel constructor +/// and query attributes, warp the pointer with Ort::Unowned instance +/// so it does not destroy the pointer the kernel does not own. +/// +struct KernelInfo : detail::KernelInfoImpl { + explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later + explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance + ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; } +}; + +/// +/// Create and own custom defined operation. 
+/// +struct Op : detail::Base { + explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used + + explicit Op(OrtOp*); ///< Take ownership of the OrtOp + + static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain, + int version, const char** type_constraint_names, + const ONNXTensorElementDataType* type_constraint_values, + size_t type_constraint_count, + const OpAttr* attr_values, + size_t attr_count, + size_t input_count, size_t output_count); + + void Invoke(const OrtKernelContext* context, + const Value* input_values, + size_t input_count, + Value* output_values, + size_t output_count); + + // For easier refactoring + void Invoke(const OrtKernelContext* context, + const OrtValue* const* input_values, + size_t input_count, + OrtValue* const* output_values, + size_t output_count); +}; + +/// +/// Provide access to per-node attributes and input shapes, so one could compute and set output shapes. +/// +struct ShapeInferContext { + struct SymbolicInteger { + SymbolicInteger(int64_t i) : i_(i), is_int_(true) {}; + SymbolicInteger(const char* s) : s_(s), is_int_(false) {}; + SymbolicInteger(const SymbolicInteger&) = default; + SymbolicInteger(SymbolicInteger&&) = default; + + SymbolicInteger& operator=(const SymbolicInteger&) = default; + SymbolicInteger& operator=(SymbolicInteger&&) = default; + + bool operator==(const SymbolicInteger& dim) const { + if (is_int_ == dim.is_int_) { + if (is_int_) { + return i_ == dim.i_; + } else { + return std::string{s_} == std::string{dim.s_}; + } + } + return false; + } + + bool IsInt() const { return is_int_; } + int64_t AsInt() const { return i_; } + const char* AsSym() const { return s_; } + + static constexpr int INVALID_INT_DIM = -2; + + private: + union { + int64_t i_; + const char* s_; + }; + bool is_int_; + }; + + using Shape = std::vector; + + ShapeInferContext(const OrtApi* ort_api, OrtShapeInferContext* ctx); + + const Shape& 
GetInputShape(size_t indice) const { return input_shapes_.at(indice); } + + size_t GetInputCount() const { return input_shapes_.size(); } + + Status SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - void ThrowOnError(OrtStatus* result); + int64_t GetAttrInt(const char* attr_name); + + using Ints = std::vector; + Ints GetAttrInts(const char* attr_name); + + float GetAttrFloat(const char* attr_name); + + using Floats = std::vector; + Floats GetAttrFloats(const char* attr_name); + + std::string GetAttrString(const char* attr_name); + + using Strings = std::vector; + Strings GetAttrStrings(const char* attr_name); private: - const OrtApi& api_; + const OrtOpAttr* GetAttrHdl(const char* attr_name) const; + const OrtApi* ort_api_; + OrtShapeInferContext* ctx_; + std::vector input_shapes_; }; -template +using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&); + +#define MAX_CUSTOM_OP_END_VER (1UL << 31) - 1 + +template struct CustomOpBase : OrtCustomOp { CustomOpBase() { OrtCustomOp::version = ORT_API_VERSION; - OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast(this_)->CreateKernel(*api, info); }; OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast(this_)->GetName(); }; OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast(this_)->GetExecutionProviderType(); }; OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast(this_)->GetInputTypeCount(); }; OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputType(index); }; + OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputMemoryType(index); }; OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast(this_)->GetOutputTypeCount(); }; 
OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputType(index); }; - OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast(op_kernel)->Compute(context); }; +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +#pragma warning(disable : 26409) +#endif OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast(op_kernel); }; - +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputCharacteristic(index); }; OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputCharacteristic(index); }; + + OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast(this_)->GetVariadicInputMinArity(); }; + OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast(static_cast(this_)->GetVariadicInputHomogeneity()); }; + OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast(this_)->GetVariadicOutputMinArity(); }; + OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast(static_cast(this_)->GetVariadicOutputHomogeneity()); }; +#ifdef __cpp_if_constexpr + if constexpr (WithStatus) { +#else + if (WithStatus) { +#endif + OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr { + return static_cast(this_)->CreateKernelV2(*api, info, op_kernel); + }; + OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { + return static_cast(op_kernel)->ComputeV2(context); + }; + } else { + OrtCustomOp::CreateKernelV2 = nullptr; + OrtCustomOp::KernelComputeV2 = nullptr; + + OrtCustomOp::CreateKernel = [](const OrtCustomOp* 
this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast(this_)->CreateKernel(*api, info); }; + OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { + static_cast(op_kernel)->Compute(context); + }; + } + + SetShapeInferFn(0); + + OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) { + return static_cast(this_)->start_ver_; + }; + + OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) { + return static_cast(this_)->end_ver_; + }; + + OrtCustomOp::GetMayInplace = nullptr; + OrtCustomOp::ReleaseMayInplace = nullptr; + OrtCustomOp::GetAliasMap = nullptr; + OrtCustomOp::ReleaseAliasMap = nullptr; } // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider @@ -968,6 +2371,63 @@ struct CustomOpBase : OrtCustomOp { OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; } + + // Default implemention of GetInputMemoryType() that returns OrtMemTypeDefault + OrtMemType GetInputMemoryType(size_t /*index*/) const { + return OrtMemTypeDefault; + } + + // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input + // should expect at least 1 argument. + int GetVariadicInputMinArity() const { + return 1; + } + + // Default implementation of GetVariadicInputHomegeneity() returns true to specify that all arguments + // to a variadic input should be of the same type. + bool GetVariadicInputHomogeneity() const { + return true; + } + + // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output + // should produce at least 1 output value. + int GetVariadicOutputMinArity() const { + return 1; + } + + // Default implementation of GetVariadicOutputHomegeneity() returns true to specify that all output values + // produced by a variadic output should be of the same type. 
+ bool GetVariadicOutputHomogeneity() const { + return true; + } + + // Declare list of session config entries used by this Custom Op. + // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs(). + // This default implementation returns an empty vector of config entries. + std::vector GetSessionConfigKeys() const { + return std::vector{}; + } + + template + decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) { + OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { + ShapeInferContext ctx(&GetApi(), ort_ctx); + return C::InferOutputShape(ctx); + }; + return {}; + } + + template + void SetShapeInferFn(...) { + OrtCustomOp::InferOutputShapeFn = {}; + } + + protected: + // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys. + void GetSessionConfigs(std::unordered_map& out, ConstSessionOptions options) const; + + int start_ver_ = 1; + int end_ver_ = MAX_CUSTOM_OP_END_VER; }; } // namespace Ort diff --git a/libs/onnxruntime/include/onnxruntime_cxx_inline.h b/libs/onnxruntime/include/onnxruntime_cxx_inline.h index 62f63e7..f1f4904 100644 --- a/libs/onnxruntime/include/onnxruntime_cxx_inline.h +++ b/libs/onnxruntime/include/onnxruntime_cxx_inline.h @@ -7,50 +7,174 @@ // These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter // the main C++ file with implementation details. 
+#include +#include +#include +#include + +// Convert OrtStatus to Ort::Status and return +// instead of throwing +#define ORT_CXX_RETURN_ON_API_FAIL(expression) \ + { \ + auto ort_status = (expression); \ + if (ort_status) { \ + return Ort::Status(ort_status); \ + } \ + } + +#ifdef __cpp_if_constexpr +#define ORT_CXX_IF_CONSTEXPR if constexpr +#else +#define ORT_CXX_IF_CONSTEXPR if +#endif + namespace Ort { -inline void ThrowOnError(const OrtApi& ort, OrtStatus* status) { - if (status) { - std::string error_message = ort.GetErrorMessage(status); - OrtErrorCode error_code = ort.GetErrorCode(status); - ort.ReleaseStatus(status); - ORT_CXX_API_THROW(std::move(error_message), error_code); +namespace detail { +inline void ThrowStatus(const Status& st) { + std::string error_message = st.GetErrorMessage(); + OrtErrorCode error_code = st.GetErrorCode(); + ORT_CXX_API_THROW(std::move(error_message), error_code); +} +} // namespace detail + +inline void ThrowOnError(OrtStatus* ort_status) { + if (ort_status) { + Ort::Status st(ort_status); + detail::ThrowStatus(st); + } +} + +inline void ThrowOnError(const Status& st) { + if (st) { + detail::ThrowStatus(st); } } -inline void ThrowOnError(OrtStatus* status) { - ThrowOnError(GetApi(), status); +inline Status::Status(OrtStatus* status) noexcept : Base{status} { +} + +inline Status::Status(const std::exception& e) noexcept { + p_ = GetApi().CreateStatus(ORT_FAIL, e.what()); +} + +inline Status::Status(const Exception& e) noexcept { + p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what()); +} + +inline Status::Status(const char* message, OrtErrorCode code) noexcept { + p_ = GetApi().CreateStatus(code, message); +} + +inline std::string Status::GetErrorMessage() const { + std::string message(GetApi().GetErrorMessage(p_)); + return message; +} + +inline OrtErrorCode Status::GetErrorCode() const { + return GetApi().GetErrorCode(p_); +} + +inline bool Status::IsOK() const noexcept { + return (p_ == nullptr); } // This template 
converts a C++ type into it's ONNXTensorElementDataType template struct TypeToTensorType; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; }; +struct 
TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; +}; + +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ; +}; template <> -struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; }; +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2; +}; +template <> +struct TypeToTensorType { + static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ; +}; + +inline bool BFloat16_t::operator==(const BFloat16_t& rhs) const noexcept { + if (IsNaN() 
|| rhs.IsNaN()) { + // IEEE defines that NaN is not equal to anything, including itself. + return false; + } + return val == rhs.val; +} + +inline bool BFloat16_t::operator<(const BFloat16_t& rhs) const noexcept { + if (IsNaN() || rhs.IsNaN()) { + // IEEE defines that NaN is unordered with respect to everything, including itself. + return false; + } + + const bool left_is_negative = IsNegative(); + if (left_is_negative != rhs.IsNegative()) { + // When the signs of left and right differ, we know that left is less than right if it is + // the negative value. The exception to this is if both values are zero, in which case IEEE + // says they should be equal, even if the signs differ. + return left_is_negative && !AreZero(*this, rhs); + } + return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative); +} inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size) : allocator_(allocator), p_(p), size_(size) { @@ -88,63 +212,92 @@ inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) noexc return *this; } -inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() { - ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&p_)); -} +namespace detail { -inline void* AllocatorWithDefaultOptions::Alloc(size_t size) { +template +inline void* AllocatorImpl::Alloc(size_t size) { void* out; - ThrowOnError(GetApi().AllocatorAlloc(p_, size, &out)); + ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out)); return out; } -inline MemoryAllocation Ort::AllocatorWithDefaultOptions::GetAllocation(size_t size) { +template +inline MemoryAllocation AllocatorImpl::GetAllocation(size_t size) { void* out; - ThrowOnError(GetApi().AllocatorAlloc(p_, size, &out)); - MemoryAllocation result(p_, out, size); + ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out)); + MemoryAllocation result(this->p_, out, size); return result; } -inline void AllocatorWithDefaultOptions::Free(void* p) { - ThrowOnError(GetApi().AllocatorFree(p_, 
p)); +template +inline void AllocatorImpl::Free(void* p) { + ThrowOnError(GetApi().AllocatorFree(this->p_, p)); } -inline const OrtMemoryInfo* AllocatorWithDefaultOptions::GetInfo() const { +template +inline ConstMemoryInfo AllocatorImpl::GetInfo() const { const OrtMemoryInfo* out; - ThrowOnError(GetApi().AllocatorGetInfo(p_, &out)); - return out; + ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out)); + return ConstMemoryInfo{out}; +} + +} // namespace detail + +inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() { + ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_)); +} + +inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) { + ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_)); } -inline std::string MemoryInfo::GetAllocatorName() const { +namespace detail { + +template +inline std::string MemoryInfoImpl::GetAllocatorName() const { const char* name = nullptr; - ThrowOnError(GetApi().MemoryInfoGetName(*this, &name)); + ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name)); return std::string(name); } -inline OrtAllocatorType MemoryInfo::GetAllocatorType() const { +template +inline OrtAllocatorType MemoryInfoImpl::GetAllocatorType() const { OrtAllocatorType type; - ThrowOnError(GetApi().MemoryInfoGetType(*this, &type)); + ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type)); return type; } -inline int MemoryInfo::GetDeviceId() const { +template +inline int MemoryInfoImpl::GetDeviceId() const { int id = 0; - ThrowOnError(GetApi().MemoryInfoGetId(*this, &id)); + ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id)); return id; } -inline OrtMemType MemoryInfo::GetMemoryType() const { +template +inline OrtMemoryInfoDeviceType MemoryInfoImpl::GetDeviceType() const { + OrtMemoryInfoDeviceType type; + GetApi().MemoryInfoGetDeviceType(this->p_, &type); + return type; +} + +template +inline OrtMemType MemoryInfoImpl::GetMemoryType() const { OrtMemType type; - 
ThrowOnError(GetApi().MemoryInfoGetMemType(*this, &type)); + ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type)); return type; } -inline bool MemoryInfo::operator==(const MemoryInfo& o) const { +template +template +inline bool MemoryInfoImpl::operator==(const MemoryInfoImpl& o) const { int comp_result = 0; - ThrowOnError(Ort::GetApi().CompareMemoryInfo(*this, o, &comp_result)); + ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result)); return comp_result == 0; } +} // namespace detail + inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) { OrtMemoryInfo* p; ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p)); @@ -152,61 +305,77 @@ inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_ty } inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { - ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &p_)); + ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_)); } -inline Allocator::Allocator(const Session& sess, const MemoryInfo& mem_info) { - ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &p_)); +namespace detail { +template +inline std::vector ConstIoBindingImpl::GetOutputNames() const { + AllocatorWithDefaultOptions allocator; + return binding_utils::GetOutputNamesHelper(this->p_, allocator); } -inline void* Allocator::Alloc(size_t size) const { - void* out = nullptr; - ThrowOnError(GetApi().AllocatorAlloc(p_, size, &out)); - return out; +template +inline std::vector ConstIoBindingImpl::GetOutputNames(OrtAllocator* allocator) const { + return binding_utils::GetOutputNamesHelper(this->p_, allocator); } -inline MemoryAllocation Ort::Allocator::GetAllocation(size_t size) { - void* out = nullptr; - ThrowOnError(GetApi().AllocatorAlloc(p_, size, &out)); - MemoryAllocation result(p_, out, size); - return result; +template +inline std::vector ConstIoBindingImpl::GetOutputValues() const { + 
AllocatorWithDefaultOptions allocator; + return binding_utils::GetOutputValuesHelper(this->p_, allocator); } -inline void Allocator::Free(void* p) const { - ThrowOnError(GetApi().AllocatorFree(p_, p)); +template +inline std::vector ConstIoBindingImpl::GetOutputValues(OrtAllocator* allocator) const { + return binding_utils::GetOutputValuesHelper(this->p_, allocator); } -inline Unowned Allocator::GetInfo() const { - const OrtMemoryInfo* out = nullptr; - ThrowOnError(GetApi().AllocatorGetInfo(p_, &out)); - return Unowned(const_cast(out)); +template +inline void IoBindingImpl::BindInput(const char* name, const Value& value) { + ThrowOnError(GetApi().BindInput(this->p_, name, value)); } -inline IoBinding::IoBinding(Session& session) { - ThrowOnError(GetApi().CreateIoBinding(session, &p_)); +template +inline void IoBindingImpl::BindOutput(const char* name, const Value& value) { + ThrowOnError(GetApi().BindOutput(this->p_, name, value)); +} + +template +inline void IoBindingImpl::BindOutput(const char* name, const OrtMemoryInfo* mem_info) { + ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info)); +} + +template +inline void IoBindingImpl::ClearBoundInputs() { + GetApi().ClearBoundInputs(this->p_); } -inline void IoBinding::BindInput(const char* name, const Value& value) { - ThrowOnError(GetApi().BindInput(p_, name, value)); +template +inline void IoBindingImpl::ClearBoundOutputs() { + GetApi().ClearBoundOutputs(this->p_); } -inline void IoBinding::BindOutput(const char* name, const Value& value) { - ThrowOnError(GetApi().BindOutput(p_, name, value)); +template +inline void IoBindingImpl::SynchronizeInputs() { + ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_)); } -inline void IoBinding::BindOutput(const char* name, const MemoryInfo& mem_info) { - ThrowOnError(GetApi().BindOutputToDevice(p_, name, mem_info)); +template +inline void IoBindingImpl::SynchronizeOutputs() { + ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_)); } -inline std::vector 
IoBinding::GetOutputNamesHelper(OrtAllocator* allocator) const { +namespace binding_utils { +inline std::vector GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) { std::vector result; - auto free_fn = [allocator](void* p) { if (p) allocator->Free(allocator, p); }; + auto free_fn = detail::AllocatedFree(allocator); using Ptr = std::unique_ptr; char* buffer = nullptr; size_t* lengths = nullptr; size_t count = 0; - ThrowOnError(GetApi().GetBoundOutputNames(p_, allocator, &buffer, &lengths, &count)); + ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count)); if (count == 0) { return result; @@ -225,16 +394,7 @@ inline std::vector IoBinding::GetOutputNamesHelper(OrtAllocator* al return result; } -inline std::vector IoBinding::GetOutputNames() const { - AllocatorWithDefaultOptions allocator; - return GetOutputNamesHelper(allocator); -} - -inline std::vector IoBinding::GetOutputNames(Allocator& allocator) const { - return GetOutputNamesHelper(allocator); -} - -inline std::vector Ort::IoBinding::GetOutputValuesHelper(OrtAllocator* allocator) const { +inline std::vector GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) { std::vector result; size_t owned = 0; size_t output_count = 0; @@ -252,7 +412,7 @@ inline std::vector Ort::IoBinding::GetOutputValuesHelper(OrtAllocator* al using Ptr = std::unique_ptr; OrtValue** output_buffer = nullptr; - ThrowOnError(GetApi().GetBoundOutputValues(p_, allocator, &output_buffer, &output_count)); + ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count)); if (output_count == 0) { return result; } @@ -267,33 +427,54 @@ inline std::vector Ort::IoBinding::GetOutputValuesHelper(OrtAllocator* al return result; } -inline std::vector Ort::IoBinding::GetOutputValues(Allocator& allocator) const { - return GetOutputValuesHelper(allocator); +} // namespace binding_utils +} // namespace detail + +inline 
IoBinding::IoBinding(Session& session) { + ThrowOnError(GetApi().CreateIoBinding(session, &this->p_)); } -inline std::vector Ort::IoBinding::GetOutputValues() const { - AllocatorWithDefaultOptions allocator; - return GetOutputValuesHelper(allocator); +inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) { + ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_)); } -inline void IoBinding::ClearBoundInputs() { - GetApi().ClearBoundInputs(p_); +inline ThreadingOptions::ThreadingOptions() { + ThrowOnError(GetApi().CreateThreadingOptions(&p_)); } -inline void IoBinding::ClearBoundOutputs() { - GetApi().ClearBoundOutputs(p_); +inline ThreadingOptions& ThreadingOptions::SetGlobalIntraOpNumThreads(int intra_op_num_threads) { + ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads)); + return *this; } -inline void IoBinding::SynchronizeInputs() { - ThrowOnError(GetApi().SynchronizeBoundInputs(p_)); +inline ThreadingOptions& ThreadingOptions::SetGlobalInterOpNumThreads(int inter_op_num_threads) { + ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads)); + return *this; } -inline void IoBinding::SynchronizeOutputs() { - ThrowOnError(GetApi().SynchronizeBoundOutputs(p_)); +inline ThreadingOptions& ThreadingOptions::SetGlobalSpinControl(int allow_spinning) { + ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning)); + return *this; } -inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) { - ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_)); +inline ThreadingOptions& ThreadingOptions::SetGlobalDenormalAsZero() { + ThrowOnError(GetApi().SetGlobalDenormalAsZero(p_)); + return *this; +} + +inline ThreadingOptions& 
ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) { + ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn)); + return *this; +} + +inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) { + ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options)); + return *this; +} + +inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) { + ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn)); + return *this; } inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) { @@ -343,19 +524,53 @@ inline Env& Env::DisableTelemetryEvents() { return *this; } +inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) { + ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level)); + return *this; +} + inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) { ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg)); return *this; } +inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map& options, const OrtArenaCfg* arena_cfg) { + std::vector keys, values; + auto num_entries = options.size(); + if (num_entries > 0) { + keys.reserve(num_entries); + values.reserve(num_entries); + for (const auto& entry : options) { + keys.push_back(entry.first.c_str()); + values.push_back(entry.second.c_str()); + } + } + ThrowOnError(GetApi().CreateAndRegisterAllocatorV2(p_, provider_type.c_str(), mem_info, arena_cfg, keys.data(), values.data(), num_entries)); + return *this; +} + inline CustomOpDomain::CustomOpDomain(const char* domain) { ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_)); } -inline void 
CustomOpDomain::Add(OrtCustomOp* op) { +inline void CustomOpDomain::Add(const OrtCustomOp* op) { ThrowOnError(GetApi().CustomOpDomain_Add(p_, op)); } +inline LoraAdapter LoraAdapter::CreateLoraAdapter(const std::basic_string& adapter_path, + OrtAllocator* allocator) { + OrtLoraAdapter* p; + ThrowOnError(GetApi().CreateLoraAdapter(adapter_path.c_str(), allocator, &p)); + return LoraAdapter{p}; +} + +inline LoraAdapter LoraAdapter::CreateLoraAdapterFromArray(const void* bytes, size_t num_bytes, + OrtAllocator* allocator) { + OrtLoraAdapter* p; + ThrowOnError(GetApi().CreateLoraAdapterFromArray(bytes, num_bytes, allocator, &p)); + return LoraAdapter{p}; +} + inline RunOptions::RunOptions() { ThrowOnError(GetApi().CreateRunOptions(&p_)); } @@ -408,454 +623,954 @@ inline RunOptions& RunOptions::UnsetTerminate() { return *this; } -inline SessionOptions::SessionOptions() { - ThrowOnError(GetApi().CreateSessionOptions(&p_)); +inline RunOptions& RunOptions::AddActiveLoraAdapter(const LoraAdapter& adapter) { + ThrowOnError(GetApi().RunOptionsAddActiveLoraAdapter(p_, adapter)); + return *this; } -inline SessionOptions SessionOptions::Clone() const { +namespace detail { + +template +inline Ort::SessionOptions ConstSessionOptionsImpl::Clone() const { OrtSessionOptions* out; - ThrowOnError(GetApi().CloneSessionOptions(p_, &out)); + ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out)); return SessionOptions{out}; } -inline SessionOptions& SessionOptions::SetIntraOpNumThreads(int intra_op_num_threads) { - ThrowOnError(GetApi().SetIntraOpNumThreads(p_, intra_op_num_threads)); - return *this; +template +inline std::string ConstSessionOptionsImpl::GetConfigEntry(const char* config_key) const { + size_t size = 0; + // Feed nullptr for the data buffer to query the true size of the string value + Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size)); + + std::string out; + out.resize(size); + 
Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size)); + out.resize(size - 1); // remove the terminating character '\0' + + return out; } -inline SessionOptions& SessionOptions::SetInterOpNumThreads(int inter_op_num_threads) { - ThrowOnError(GetApi().SetInterOpNumThreads(p_, inter_op_num_threads)); +template +inline bool ConstSessionOptionsImpl::HasConfigEntry(const char* config_key) const { + int out = 0; + Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out)); + return static_cast(out); +} + +template +inline std::string ConstSessionOptionsImpl::GetConfigEntryOrDefault(const char* config_key, const std::string& def) { + if (!this->HasConfigEntry(config_key)) { + return def; + } + + return this->GetConfigEntry(config_key); +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetIntraOpNumThreads(int intra_op_num_threads) { + ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads)); return *this; } -inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) { - ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(p_, graph_optimization_level)); +template +inline SessionOptionsImpl& SessionOptionsImpl::SetInterOpNumThreads(int inter_op_num_threads) { + ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads)); return *this; } -inline SessionOptions& SessionOptions::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) { - ThrowOnError(GetApi().SetOptimizedModelFilePath(p_, optimized_model_filepath)); +template +inline SessionOptionsImpl& SessionOptionsImpl::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) { + ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level)); return *this; } -inline SessionOptions& SessionOptions::EnableProfiling(const ORTCHAR_T* profile_file_prefix) { - ThrowOnError(GetApi().EnableProfiling(p_, 
profile_file_prefix)); +template +inline SessionOptionsImpl& SessionOptionsImpl::SetDeterministicCompute(bool value) { + ThrowOnError(GetApi().SetDeterministicCompute(this->p_, value)); return *this; } -inline SessionOptions& SessionOptions::DisableProfiling() { - ThrowOnError(GetApi().DisableProfiling(p_)); +template +inline SessionOptionsImpl& SessionOptionsImpl::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) { + ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath)); return *this; } -inline SessionOptions& SessionOptions::EnableOrtCustomOps() { - ThrowOnError(GetApi().EnableOrtCustomOps(p_)); +template +inline SessionOptionsImpl& SessionOptionsImpl::EnableProfiling(const ORTCHAR_T* profile_file_prefix) { + ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix)); return *this; } -inline SessionOptions& SessionOptions::EnableMemPattern() { - ThrowOnError(GetApi().EnableMemPattern(p_)); +template +inline SessionOptionsImpl& SessionOptionsImpl::DisableProfiling() { + ThrowOnError(GetApi().DisableProfiling(this->p_)); return *this; } -inline SessionOptions& SessionOptions::DisableMemPattern() { - ThrowOnError(GetApi().DisableMemPattern(p_)); +template +inline SessionOptionsImpl& SessionOptionsImpl::EnableOrtCustomOps() { + ThrowOnError(GetApi().EnableOrtCustomOps(this->p_)); return *this; } -inline SessionOptions& SessionOptions::EnableCpuMemArena() { - ThrowOnError(GetApi().EnableCpuMemArena(p_)); +template +inline SessionOptionsImpl& SessionOptionsImpl::EnableMemPattern() { + ThrowOnError(GetApi().EnableMemPattern(this->p_)); return *this; } -inline SessionOptions& SessionOptions::DisableCpuMemArena() { - ThrowOnError(GetApi().DisableCpuMemArena(p_)); +template +inline SessionOptionsImpl& SessionOptionsImpl::DisableMemPattern() { + ThrowOnError(GetApi().DisableMemPattern(this->p_)); return *this; } -inline SessionOptions& SessionOptions::SetExecutionMode(ExecutionMode execution_mode) { - 
ThrowOnError(GetApi().SetSessionExecutionMode(p_, execution_mode)); +template +inline SessionOptionsImpl& SessionOptionsImpl::EnableCpuMemArena() { + ThrowOnError(GetApi().EnableCpuMemArena(this->p_)); return *this; } -inline SessionOptions& SessionOptions::SetLogId(const char* logid) { - ThrowOnError(GetApi().SetSessionLogId(p_, logid)); +template +inline SessionOptionsImpl& SessionOptionsImpl::DisableCpuMemArena() { + ThrowOnError(GetApi().DisableCpuMemArena(this->p_)); return *this; } -inline SessionOptions& SessionOptions::SetLogSeverityLevel(int level) { - ThrowOnError(GetApi().SetSessionLogSeverityLevel(p_, level)); +template +inline SessionOptionsImpl& SessionOptionsImpl::SetExecutionMode(ExecutionMode execution_mode) { + ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode)); return *this; } -inline SessionOptions& SessionOptions::Add(OrtCustomOpDomain* custom_op_domain) { - ThrowOnError(GetApi().AddCustomOpDomain(p_, custom_op_domain)); +template +inline SessionOptionsImpl& SessionOptionsImpl::SetLogId(const char* logid) { + ThrowOnError(GetApi().SetSessionLogId(this->p_, logid)); return *this; } -inline SessionOptions& SessionOptions::AddConfigEntry(const char* config_key, const char* config_value) { - ThrowOnError(GetApi().AddSessionConfigEntry(p_, config_key, config_value)); +template +inline SessionOptionsImpl& SessionOptionsImpl::SetLogSeverityLevel(int level) { + ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level)); return *this; } -inline SessionOptions& SessionOptions::AddInitializer(const char* name, const OrtValue* ort_val) { - ThrowOnError(GetApi().AddInitializer(p_, name, ort_val)); +template +inline SessionOptionsImpl& SessionOptionsImpl::Add(OrtCustomOpDomain* custom_op_domain) { + ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain)); return *this; } -inline SessionOptions& SessionOptions::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) { - 
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(p_, &provider_options)); +template +inline SessionOptionsImpl& SessionOptionsImpl::AddConfigEntry(const char* config_key, const char* config_value) { + ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value)); return *this; } -inline SessionOptions& SessionOptions::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) { - ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(p_, &provider_options)); +template +inline SessionOptionsImpl& SessionOptionsImpl::AddInitializer(const char* name, const OrtValue* ort_val) { + ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val)); return *this; } -inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) { - ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(p_, &provider_options)); +template +inline SessionOptionsImpl& SessionOptionsImpl::DisablePerSessionThreads() { + ThrowOnError(GetApi().DisablePerSessionThreads(this->p_)); return *this; } -inline SessionOptions& SessionOptions::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) { - ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(p_, ort_custom_create_thread_fn)); +template +inline SessionOptionsImpl& SessionOptionsImpl::AddExternalInitializers(const std::vector& names, + const std::vector& ort_values) { + const size_t inputs_num = names.size(); + if (inputs_num != ort_values.size()) { + ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT); + } + std::vector names_ptr; + std::vector ort_values_ptrs; + names_ptr.reserve(inputs_num); + ort_values_ptrs.reserve(inputs_num); + for (size_t i = 0; i < inputs_num; ++i) { + names_ptr.push_back(names[i].c_str()); + ort_values_ptrs.push_back(ort_values[i]); + } + ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), 
ort_values_ptrs.data(), inputs_num)); return *this; } -inline SessionOptions& SessionOptions::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) { - ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(p_, ort_custom_thread_creation_options)); +template +inline SessionOptionsImpl& SessionOptionsImpl::AddExternalInitializersFromFilesInMemory(const std::vector>& file_names, + const std::vector& buffer_array, + const std::vector& file_lengths) { + const size_t inputs_num = file_names.size(); + if (inputs_num != buffer_array.size()) { + ORT_CXX_API_THROW("Expecting names and buffer_array to have the same length", ORT_INVALID_ARGUMENT); + } + if (inputs_num != file_lengths.size()) { + ORT_CXX_API_THROW("Expecting names and file_lengths to have the same length", ORT_INVALID_ARGUMENT); + } + std::vector names_ptr; + names_ptr.reserve(inputs_num); + for (size_t i = 0; i < inputs_num; ++i) { + names_ptr.push_back(file_names[i].c_str()); + } + ThrowOnError(GetApi().AddExternalInitializersFromFilesInMemory(this->p_, names_ptr.data(), buffer_array.data(), + file_lengths.data(), inputs_num)); return *this; } -inline SessionOptions& SessionOptions::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) { - ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(p_, ort_custom_join_thread_fn)); +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options)); return *this; } -inline SessionOptions& SessionOptions::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) { - ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(p_, &provider_options)); +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options) { + 
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options)); return *this; } -inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) { - ThrowOnError(GetApi().CreateSession(env, model_path, options, &p_)); +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options)); + return *this; } -inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, - OrtPrepackedWeightsContainer* prepacked_weights_container) { - ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &p_)); +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options)); + return *this; } -inline Session::Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) { - ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &p_)); +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options)); + return *this; } -inline std::vector Session::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, - const char* const* output_names, size_t output_names_count) { - std::vector output_values; - for (size_t i = 0; i < output_names_count; i++) - output_values.emplace_back(nullptr); - Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), 
output_names_count); - return output_values; +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options)); + return *this; } -inline void Session::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, - const char* const* output_names, Value* output_values, size_t output_count) { - static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely"); - auto ort_input_values = reinterpret_cast(const_cast(input_values)); - auto ort_output_values = reinterpret_cast(output_values); - ThrowOnError(GetApi().Run(p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values)); +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options)); + return *this; } -inline void Session::Run(const RunOptions& run_options, const IoBinding& io_binding) { - ThrowOnError(GetApi().RunWithBinding(p_, run_options, io_binding)); +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_Dnnl(this->p_, &provider_options)); + return *this; } -inline size_t Session::GetInputCount() const { - size_t out; - ThrowOnError(GetApi().SessionGetInputCount(p_, &out)); - return out; +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider( + const std::string& provider_name, + const std::unordered_map& provider_options) { + auto num_entries = provider_options.size(); + std::vector keys, values; + if (num_entries > 0) 
{ + keys.reserve(num_entries); + values.reserve(num_entries); + + for (const auto& entry : provider_options) { + keys.push_back(entry.first.c_str()); + values.push_back(entry.second.c_str()); + } + } + + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(), + keys.data(), values.data(), num_entries)); + + return *this; } -inline size_t Session::GetOutputCount() const { - size_t out; - ThrowOnError(GetApi().SessionGetOutputCount(p_, &out)); - return out; +template +inline SessionOptionsImpl& SessionOptionsImpl::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) { + ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn)); + return *this; } -inline size_t Session::GetOverridableInitializerCount() const { - size_t out; - ThrowOnError(GetApi().SessionGetOverridableInitializerCount(p_, &out)); - return out; +template +inline SessionOptionsImpl& SessionOptionsImpl::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) { + ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options)); + return *this; } -inline char* Session::GetInputName(size_t index, OrtAllocator* allocator) const { - char* out; - ThrowOnError(GetApi().SessionGetInputName(p_, index, allocator, &out)); - return out; +template +inline SessionOptionsImpl& SessionOptionsImpl::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) { + ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn)); + return *this; } -inline char* Session::GetOutputName(size_t index, OrtAllocator* allocator) const { - char* out; - ThrowOnError(GetApi().SessionGetOutputName(p_, index, allocator, &out)); - return out; +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) { + 
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options)); + return *this; } -inline char* Session::GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const { - char* out; - ThrowOnError(GetApi().SessionGetOverridableInitializerName(p_, index, allocator, &out)); - return out; +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_OpenVINO_V2(const std::unordered_map& provider_options) { + auto num_entries = provider_options.size(); + std::vector keys, values; + if (num_entries > 0) { + keys.reserve(num_entries); + values.reserve(num_entries); + + for (const auto& entry : provider_options) { + keys.push_back(entry.first.c_str()); + values.push_back(entry.second.c_str()); + } + } + + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO_V2(this->p_, + keys.data(), values.data(), num_entries)); + + return *this; } -inline char* Session::EndProfiling(OrtAllocator* allocator) const { - char* out; - ThrowOnError(GetApi().SessionEndProfiling(p_, allocator, &out)); - return out; +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options) { + auto num_entries = provider_options.size(); + std::vector keys, values; + if (num_entries > 0) { + keys.reserve(num_entries); + values.reserve(num_entries); + + for (const auto& entry : provider_options) { + keys.push_back(entry.first.c_str()); + values.push_back(entry.second.c_str()); + } + } + + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_VitisAI(this->p_, keys.data(), values.data(), num_entries)); + + return *this; } -inline uint64_t Session::GetProfilingStartTimeNs() const { - uint64_t out; - ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(p_, &out)); - return out; +template +inline SessionOptionsImpl& SessionOptionsImpl::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, + const CustomOpConfigs& custom_op_configs) { + // Add custom 
op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by + // the custom op library. + for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) { + AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str()); + } + + ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name)); + return *this; } -inline ModelMetadata Session::GetModelMetadata() const { - OrtModelMetadata* out; - ThrowOnError(GetApi().SessionGetModelMetadata(p_, &out)); - return ModelMetadata{out}; +template +inline SessionOptionsImpl& SessionOptionsImpl::RegisterCustomOpsUsingFunction(const char* registration_function_name) { + ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name)); + return *this; } -inline char* ModelMetadata::GetProducerName(OrtAllocator* allocator) const { - char* out; - ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out)); +/// Session +template +inline size_t ConstSessionImpl::GetInputCount() const { + size_t out; + ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out)); return out; } -inline char* ModelMetadata::GetGraphName(OrtAllocator* allocator) const { - char* out; - ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out)); +template +inline size_t ConstSessionImpl::GetOutputCount() const { + size_t out; + ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out)); return out; } -inline char* ModelMetadata::GetDomain(OrtAllocator* allocator) const { - char* out; - ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out)); +template +inline size_t ConstSessionImpl::GetOverridableInitializerCount() const { + size_t out; + ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out)); return out; } -inline char* ModelMetadata::GetDescription(OrtAllocator* allocator) const { +template +inline AllocatedStringPtr ConstSessionImpl::GetInputNameAllocated(size_t index, OrtAllocator* 
allocator) const { char* out; - ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out)); - return out; + ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); } -inline char* ModelMetadata::GetGraphDescription(OrtAllocator* allocator) const { +template +inline AllocatedStringPtr ConstSessionImpl::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const { char* out; - ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out)); - return out; + ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); } -inline char* ModelMetadata::LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const { +template +inline AllocatedStringPtr ConstSessionImpl::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const { char* out; - ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out)); - return out; + ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); } -inline char** ModelMetadata::GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const { - char** out; - ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys)); +template +inline uint64_t ConstSessionImpl::GetProfilingStartTimeNs() const { + uint64_t out; + ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out)); return out; } -inline int64_t ModelMetadata::GetVersion() const { - int64_t out; - ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out)); - return out; +template +inline ModelMetadata ConstSessionImpl::GetModelMetadata() const { + OrtModelMetadata* out; + ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out)); + return ModelMetadata{out}; 
} -inline TypeInfo Session::GetInputTypeInfo(size_t index) const { +template +inline TypeInfo ConstSessionImpl::GetInputTypeInfo(size_t index) const { OrtTypeInfo* out; - ThrowOnError(GetApi().SessionGetInputTypeInfo(p_, index, &out)); + ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out)); return TypeInfo{out}; } -inline TypeInfo Session::GetOutputTypeInfo(size_t index) const { +template +inline TypeInfo ConstSessionImpl::GetOutputTypeInfo(size_t index) const { OrtTypeInfo* out; - ThrowOnError(GetApi().SessionGetOutputTypeInfo(p_, index, &out)); + ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out)); return TypeInfo{out}; } -inline TypeInfo Session::GetOverridableInitializerTypeInfo(size_t index) const { +template +inline TypeInfo ConstSessionImpl::GetOverridableInitializerTypeInfo(size_t index) const { OrtTypeInfo* out; - ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(p_, index, &out)); + ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out)); return TypeInfo{out}; } -inline ONNXTensorElementDataType TensorTypeAndShapeInfo::GetElementType() const { +template +inline std::vector SessionImpl::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, size_t output_count) { + std::vector output_values; + output_values.reserve(output_count); + for (size_t i = 0; i < output_count; i++) + output_values.emplace_back(nullptr); + Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count); + return output_values; +} + +template +inline void SessionImpl::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, Value* output_values, size_t output_count) { + static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we 
can reinterpret_cast safely"); + auto ort_input_values = reinterpret_cast(input_values); + auto ort_output_values = reinterpret_cast(output_values); + ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values)); +} + +template +inline void SessionImpl::Run(const RunOptions& run_options, const IoBinding& io_binding) { + ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding)); +} + +template +inline void SessionImpl::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) { + auto ort_input_values = reinterpret_cast(input_values); + auto ort_output_values = reinterpret_cast(output_values); + ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names, + ort_input_values, input_count, output_names, output_count, + ort_output_values, callback, user_data)); +} + +template +inline AllocatedStringPtr SessionImpl::EndProfilingAllocated(OrtAllocator* allocator) { + char* out = nullptr; + ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +template +inline void SessionImpl::SetEpDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len) { + ThrowOnError(GetApi().SetEpDynamicOptions(this->p_, keys, values, kv_len)); +} + +} // namespace detail + +inline SessionOptions::SessionOptions() { + ThrowOnError(GetApi().CreateSessionOptions(&this->p_)); +} + +/// CustomOpConfigs +inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) { + std::string config_key = "custom_op."; + + config_key += custom_op_name; + config_key += "."; + config_key += config; + + return config_key; +} + +inline CustomOpConfigs& CustomOpConfigs::AddConfig(const 
char* custom_op_name, const char* config_key, const char* config_value) { + const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key); + flat_configs_[full_flat_key] = config_value; + return *this; +} + +inline const std::unordered_map& CustomOpConfigs::GetFlattenedConfigs() const { + return flat_configs_; +} + +inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) { + ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_)); +} + +inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container) { + ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_)); +} + +inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) { + ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_)); +} + +inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, + const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) { + ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options, + prepacked_weights_container, &this->p_)); +} + +inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline AllocatedStringPtr 
ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out)); + return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); +} + +inline std::vector ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const { + auto deletor = detail::AllocatedFree(allocator); + std::vector result; + + char** out = nullptr; + int64_t num_keys = 0; + ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys)); + if (num_keys <= 0) { + return result; + } + + // array of pointers will be freed + std::unique_ptr array_guard(out, deletor); + // reserve may throw + auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); }; + std::unique_ptr strings_guard(out, strings_deletor); + result.reserve(static_cast(num_keys)); + strings_guard.release(); + for (int64_t i = 0; i < num_keys; ++i) { + result.push_back(AllocatedStringPtr(out[i], deletor)); + } + + return result; +} + +inline int64_t ModelMetadata::GetVersion() const { + int64_t out; + 
ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out)); + return out; +} + +namespace detail { + +template +inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl::GetElementType() const { ONNXTensorElementDataType out; - ThrowOnError(GetApi().GetTensorElementType(p_, &out)); + ThrowOnError(GetApi().GetTensorElementType(this->p_, &out)); return out; } -inline size_t TensorTypeAndShapeInfo::GetElementCount() const { +template +inline size_t TensorTypeAndShapeInfoImpl::GetElementCount() const { size_t out; - ThrowOnError(GetApi().GetTensorShapeElementCount(p_, &out)); + ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out)); return static_cast(out); } -inline size_t TensorTypeAndShapeInfo::GetDimensionsCount() const { +template +inline size_t TensorTypeAndShapeInfoImpl::GetDimensionsCount() const { size_t out; - ThrowOnError(GetApi().GetDimensionsCount(p_, &out)); + ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out)); return out; } -inline void TensorTypeAndShapeInfo::GetDimensions(int64_t* values, size_t values_count) const { - ThrowOnError(GetApi().GetDimensions(p_, values, values_count)); +template +inline void TensorTypeAndShapeInfoImpl::GetDimensions(int64_t* values, size_t values_count) const { + ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count)); } -inline void TensorTypeAndShapeInfo::GetSymbolicDimensions(const char** values, size_t values_count) const { - ThrowOnError(GetApi().GetSymbolicDimensions(p_, values, values_count)); +template +inline void TensorTypeAndShapeInfoImpl::GetSymbolicDimensions(const char** values, size_t values_count) const { + ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count)); } -inline std::vector TensorTypeAndShapeInfo::GetShape() const { +template +inline std::vector TensorTypeAndShapeInfoImpl::GetShape() const { std::vector out(GetDimensionsCount(), 0); - GetDimensions(out.data(), out.size()); + ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), 
out.size())); return out; } -inline Unowned TypeInfo::GetTensorTypeAndShapeInfo() const { +template +inline ConstTensorTypeAndShapeInfo TypeInfoImpl::GetTensorTypeAndShapeInfo() const { const OrtTensorTypeAndShapeInfo* out; - ThrowOnError(GetApi().CastTypeInfoToTensorInfo(p_, &out)); - return Unowned(const_cast(out)); + ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out)); + return ConstTensorTypeAndShapeInfo{out}; } -inline Unowned TypeInfo::GetSequenceTypeInfo() const { +template +inline ConstSequenceTypeInfo TypeInfoImpl::GetSequenceTypeInfo() const { const OrtSequenceTypeInfo* out; - ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(p_, &out)); - return Unowned{const_cast(out)}; + ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out)); + return ConstSequenceTypeInfo{out}; +} + +template +inline ConstMapTypeInfo TypeInfoImpl::GetMapTypeInfo() const { + const OrtMapTypeInfo* out; + ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out)); + return ConstMapTypeInfo{out}; +} + +template +inline ONNXType TypeInfoImpl::GetONNXType() const { + ONNXType out; + ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out)); + return out; } -inline TypeInfo SequenceTypeInfo::GetSequenceElementType() const { +template +inline TypeInfo SequenceTypeInfoImpl::GetSequenceElementType() const { OrtTypeInfo* output; - ThrowOnError(GetApi().GetSequenceElementType(p_, &output)); + ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output)); return TypeInfo{output}; } -inline Unowned TypeInfo::GetMapTypeInfo() const { - const OrtMapTypeInfo* out; - ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(p_, &out)); - return Unowned{const_cast(out)}; +template +inline TypeInfo OptionalTypeInfoImpl::GetOptionalElementType() const { + OrtTypeInfo* info; + ThrowOnError(GetApi().GetOptionalContainedTypeInfo(this->p_, &info)); + return TypeInfo{info}; } -inline ONNXTensorElementDataType MapTypeInfo::GetMapKeyType() const { +template +inline 
ONNXTensorElementDataType MapTypeInfoImpl::GetMapKeyType() const { ONNXTensorElementDataType out; - ThrowOnError(GetApi().GetMapKeyType(p_, &out)); + ThrowOnError(GetApi().GetMapKeyType(this->p_, &out)); return out; } -inline TypeInfo MapTypeInfo::GetMapValueType() const { +template +inline TypeInfo MapTypeInfoImpl::GetMapValueType() const { OrtTypeInfo* output; - ThrowOnError(GetApi().GetMapValueType(p_, &output)); + ThrowOnError(GetApi().GetMapValueType(this->p_, &output)); return TypeInfo{output}; } -inline ONNXType TypeInfo::GetONNXType() const { - ONNXType out; - ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(p_, &out)); - return out; +template +inline ConstOptionalTypeInfo TypeInfoImpl::GetOptionalTypeInfo() const { + const OrtOptionalTypeInfo* info; + ThrowOnError(GetApi().CastTypeInfoToOptionalTypeInfo(this->p_, &info)); + return ConstOptionalTypeInfo{info}; } +} // namespace detail + +namespace detail { + template -inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) { - return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType::type); +template +inline void ConstValueImpl::GetOpaqueData(const char* domain, const char* type_name, R& out) const { + ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R))); } -inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, - ONNXTensorElementDataType type) { - OrtValue* out; - ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out)); - return Value{out}; +template +inline bool ConstValueImpl::IsTensor() const { + int out; + ThrowOnError(GetApi().IsTensor(this->p_, &out)); + return out != 0; } -#if !defined(DISABLE_SPARSE_TENSORS) template -inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const 
Shape& dense_shape, - const Shape& values_shape) { - return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType::type); +inline bool ConstValueImpl::HasValue() const { + int out; + ThrowOnError(GetApi().HasValue(this->p_, &out)); + return out != 0; } -inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape, - const Shape& values_shape, ONNXTensorElementDataType type) { +template +inline size_t ConstValueImpl::GetCount() const { + size_t out; + ThrowOnError(GetApi().GetValueCount(this->p_, &out)); + return out; +} + +template +inline Value ConstValueImpl::GetValue(int index, OrtAllocator* allocator) const { OrtValue* out; - ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len, - values_shape.shape, values_shape.shape_len, type, &out)); + ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out)); return Value{out}; } -inline void Value::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param, - const int64_t* indices_data, size_t indices_num) { - ThrowOnError(GetApi().FillSparseTensorCoo(p_, mem_info, values_param.values_shape, - values_param.values_shape_len, values_param.data.p_data, - indices_data, indices_num)); +template +inline size_t ConstValueImpl::GetStringTensorDataLength() const { + size_t out; + ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out)); + return out; } -inline void Value::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info, - const OrtSparseValuesParam& values, - const int64_t* inner_indices_data, size_t inner_indices_num, - const int64_t* outer_indices_data, size_t outer_indices_num) { - ThrowOnError(GetApi().FillSparseTensorCsr(p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data, - inner_indices_data, inner_indices_num, - outer_indices_data, outer_indices_num)); +template +inline size_t 
ConstValueImpl::GetStringTensorElementLength(size_t element_index) const { + size_t out; + ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out)); + return out; } -inline void Value::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info, - const OrtSparseValuesParam& values, - const Shape& indices_shape, - const int32_t* indices_data) { - ThrowOnError(GetApi().FillSparseTensorBlockSparse(p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data, - indices_shape.shape, indices_shape.shape_len, - indices_data)); +template +template +inline const R* ConstValueImpl::GetTensorData() const { + R* out; + ThrowOnError(GetApi().GetTensorMutableData(const_cast(this->p_), (void**)&out)); + return out; } -inline void Value::UseCooIndices(int64_t* indices_data, size_t indices_num) { - ThrowOnError(GetApi().UseCooIndices(p_, indices_data, indices_num)); +template +inline const void* ConstValueImpl::GetTensorRawData() const { + void* out; + ThrowOnError(GetApi().GetTensorMutableData(const_cast(this->p_), &out)); + return out; } -inline void Value::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) { - ThrowOnError(GetApi().UseCsrIndices(p_, inner_data, inner_num, outer_data, outer_num)); +template +inline TypeInfo ConstValueImpl::GetTypeInfo() const { + OrtTypeInfo* output; + ThrowOnError(GetApi().GetTypeInfo(this->p_, &output)); + return TypeInfo{output}; +} + +template +inline TensorTypeAndShapeInfo ConstValueImpl::GetTensorTypeAndShapeInfo() const { + OrtTensorTypeAndShapeInfo* output; + ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output)); + return TensorTypeAndShapeInfo{output}; } -inline void Value::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) { - ThrowOnError(GetApi().UseBlockSparseIndices(p_, indices_shape.shape, indices_shape.shape_len, indices_data)); +template +inline ConstMemoryInfo ConstValueImpl::GetTensorMemoryInfo() const { 
+ const OrtMemoryInfo* mem_info; + ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info)); + return ConstMemoryInfo(mem_info); +} + +template +inline void ConstValueImpl::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const { + ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer)); +} + +template +inline std::string ConstValueImpl::GetStringTensorElement(size_t element_index) const { + size_t buffer_length; + ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &buffer_length)); + + std::string s; + s.resize(buffer_length); + ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, &s[0])); + return s; } -inline OrtSparseFormat Value::GetSparseFormat() const { +template +inline void ConstValueImpl::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const { + ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count)); +} + +#if !defined(DISABLE_SPARSE_TENSORS) +template +inline OrtSparseFormat ConstValueImpl::GetSparseFormat() const { OrtSparseFormat format; - ThrowOnError(GetApi().GetSparseTensorFormat(p_, &format)); + ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format)); return format; } -inline TensorTypeAndShapeInfo Value::GetSparseTensorValuesTypeAndShapeInfo() const { +template +inline TensorTypeAndShapeInfo ConstValueImpl::GetSparseTensorValuesTypeAndShapeInfo() const { OrtTensorTypeAndShapeInfo* output; - ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(p_, &output)); + ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output)); return TensorTypeAndShapeInfo{output}; } -inline TensorTypeAndShapeInfo Value::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const { +template +inline TensorTypeAndShapeInfo ConstValueImpl::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat 
indices_format) const { OrtTensorTypeAndShapeInfo* output; - ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(p_, indices_format, &output)); + ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output)); return TensorTypeAndShapeInfo{output}; } template -inline const T* Value::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const { +template +inline const R* ConstValueImpl::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const { const void* out; - ThrowOnError(GetApi().GetSparseTensorIndices(p_, indices_format, &num_indices, &out)); - return reinterpret_cast(out); + ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out)); + return reinterpret_cast(out); +} + +template +inline bool ConstValueImpl::IsSparseTensor() const { + int out; + ThrowOnError(GetApi().IsSparseTensor(this->p_, &out)); + return out != 0; +} + +template +template +inline const R* ConstValueImpl::GetSparseTensorValues() const { + const void* out; + ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out)); + return reinterpret_cast(out); +} + +#endif + +template +void ValueImpl::FillStringTensor(const char* const* s, size_t s_len) { + ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len)); +} + +template +void ValueImpl::FillStringTensorElement(const char* s, size_t index) { + ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index)); +} + +template +inline char* ValueImpl::GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length) { + char* result; + ThrowOnError(GetApi().GetResizedStringTensorElementBuffer(this->p_, index, buffer_length, &result)); + return result; +} + +template +void* ValueImpl::GetTensorMutableRawData() { + void* out; + ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out)); + return out; +} + +template +template +R* ValueImpl::GetTensorMutableData() { + R* out; + 
ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out)); + return out; +} + +template +template +R& ValueImpl::At(const std::vector& location) { + static_assert(!std::is_same::value, "this api does not support std::string"); + R* out; + ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out)); + return *out; } + +#if !defined(DISABLE_SPARSE_TENSORS) +template +void ValueImpl::UseCooIndices(int64_t* indices_data, size_t indices_num) { + ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num)); +} + +template +void ValueImpl::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) { + ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num)); +} + +template +void ValueImpl::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) { + ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data)); +} + +template +void ValueImpl::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param, + const int64_t* indices_data, size_t indices_num) { + ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape, + values_param.values_shape_len, values_param.data.p_data, + indices_data, indices_num)); +} + +template +void ValueImpl::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info, + const OrtSparseValuesParam& values, + const int64_t* inner_indices_data, size_t inner_indices_num, + const int64_t* outer_indices_data, size_t outer_indices_num) { + ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data, + inner_indices_data, inner_indices_num, + outer_indices_data, outer_indices_num)); +} + +template +void ValueImpl::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info, + const OrtSparseValuesParam& values, + const Shape& 
indices_shape, + const int32_t* indices_data) { + ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data, + indices_shape.shape, indices_shape.shape_len, + indices_data)); +} + #endif // !defined(DISABLE_SPARSE_TENSORS) +} // namespace detail + +template +inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) { + return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType::type); +} + +inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type) { + OrtValue* out; + ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out)); + return Value{out}; +} + template inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) { return CreateTensor(allocator, shape, shape_len, TypeToTensorType::type); @@ -868,6 +1583,21 @@ inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, } #if !defined(DISABLE_SPARSE_TENSORS) + +template +inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape, + const Shape& values_shape) { + return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType::type); +} + +inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape, + const Shape& values_shape, ONNXTensorElementDataType type) { + OrtValue* out; + ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len, + values_shape.shape, values_shape.shape_len, type, &out)); + return Value{out}; +} + template inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& 
dense_shape) { return CreateSparseTensor(allocator, dense_shape, TypeToTensorType::type); @@ -881,16 +1611,16 @@ inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& den } #endif // !defined(DISABLE_SPARSE_TENSORS) -inline Value Value::CreateMap(Value& keys, Value& values) { +inline Value Value::CreateMap(const Value& keys, const Value& values) { OrtValue* out; - OrtValue* inputs[2] = {keys, values}; + const OrtValue* inputs[2] = {keys, values}; ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out)); return Value{out}; } -inline Value Value::CreateSequence(std::vector& values) { +inline Value Value::CreateSequence(const std::vector& values) { OrtValue* out; - std::vector values_ort{values.data(), values.data() + values.size()}; + std::vector values_ort{values.data(), values.data() + values.size()}; ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out)); return Value{out}; } @@ -902,292 +1632,539 @@ inline Value Value::CreateOpaque(const char* domain, const char* type_name, cons return Value{out}; } -template -inline void Value::GetOpaqueData(const char* domain, const char* type_name, T& out) const { - ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, p_, &out, sizeof(T))); +// +// Custom OP Inlines +// +inline Logger::Logger(const OrtLogger* logger) : logger_(logger) { + Ort::ThrowOnError(GetApi().Logger_GetLoggingSeverityLevel(this->logger_, &this->cached_severity_level_)); } -inline bool Value::IsTensor() const { - int out; - ThrowOnError(GetApi().IsTensor(p_, &out)); - return out != 0; +inline OrtLoggingLevel Logger::GetLoggingSeverityLevel() const noexcept { + return cached_severity_level_; } -inline bool Value::HasValue() const { - int out; - ThrowOnError(GetApi().HasValue(p_, &out)); - return out != 0; +inline Status Logger::LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number, + const char* func_name, const char* message) const noexcept 
{ + OrtStatus* status = GetApi().Logger_LogMessage(logger_, log_severity_level, message, file_path, line_number, + func_name); + return Status{status}; } -#if !defined(DISABLE_SPARSE_TENSORS) -inline bool Value::IsSparseTensor() const { - int out; - ThrowOnError(GetApi().IsSparseTensor(p_, &out)); - return out != 0; +// Disable warnings about the format string not being a literal (-Wformat-nonliteral and -Wformat-security) +// for gcc and clang. The alternative is to use actual C-style variadic parameters and apply +// __attribute__(format(printf...)), which does not work with variadic templates. +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wformat-nonliteral" +#pragma GCC diagnostic ignored "-Wformat-security" +#elif defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wformat-nonliteral" +#pragma clang diagnostic ignored "-Wformat-security" +#endif +template +inline Status Logger::LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, + int line_number, const char* func_name, const char* format, + Args&&... args) const noexcept { + int msg_len = std::snprintf(nullptr, 0U, format, std::forward(args)...); + + if (msg_len < 0) { // Formatting error + return Status("Failed to log message due to formatting error", OrtErrorCode::ORT_FAIL); + } + + OrtStatus* status = nullptr; + const size_t buffer_size = static_cast(msg_len) + 1U; + + constexpr size_t kStackBufferSize = 1024; + + if (buffer_size < kStackBufferSize) { + char buffer[kStackBufferSize]; + snprintf(buffer, kStackBufferSize, format, std::forward(args)...); + status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer, file_path, line_number, func_name); + } else { + // std::make_unique is only supported starting at C++14. 
+#if (__cplusplus >= 201402L) || (_MSC_VER >= 1900) + auto buffer = std::make_unique(buffer_size); +#else + std::unique_ptr buffer(new char[buffer_size]); +#endif + std::snprintf(buffer.get(), buffer_size, format, std::forward(args)...); + status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer.get(), file_path, line_number, func_name); + } + + return Status{status}; } +// Re-enable -Wformat-nonliteral and -Wformat-security +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#elif defined(__clang__) +#pragma clang diagnostic pop #endif -inline size_t Value::GetCount() const { - size_t out; - ThrowOnError(GetApi().GetValueCount(p_, &out)); +inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) { +} + +inline size_t KernelContext::GetInputCount() const { + size_t out = 0; + Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out)); return out; } -inline Value Value::GetValue(int index, OrtAllocator* allocator) const { - OrtValue* out; - ThrowOnError(GetApi().GetValue(p_, index, allocator, &out)); - return Value{out}; +inline size_t KernelContext::GetOutputCount() const { + size_t out = 0; + Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out)); + return out; } -inline size_t Value::GetStringTensorDataLength() const { - size_t out; - ThrowOnError(GetApi().GetStringTensorDataLength(p_, &out)); +inline ConstValue KernelContext::GetInput(size_t index) const { + const OrtValue* out = nullptr; + Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out)); + return ConstValue{out}; +} + +inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const { + OrtValue* out = nullptr; + Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out)); + return UnownedValue(out); +} + +inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector& dims) const { + OrtValue* out = nullptr; + 
Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out)); + return UnownedValue(out); +} + +inline void* KernelContext::GetGPUComputeStream() const { + void* out = nullptr; + Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out)); return out; } -inline size_t Value::GetStringTensorElementLength(size_t element_index) const { - size_t out; - ThrowOnError(GetApi().GetStringTensorElementLength(p_, element_index, &out)); +inline OrtAllocator* KernelContext::GetAllocator(const OrtMemoryInfo& memory_info) const { + OrtAllocator* out = nullptr; + Ort::ThrowOnError(GetApi().KernelContext_GetAllocator(ctx_, &memory_info, &out)); return out; } -inline void Value::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const { - ThrowOnError(GetApi().GetStringTensorContent(p_, buffer, buffer_length, offsets, offsets_count)); +inline Logger KernelContext::GetLogger() const { + const OrtLogger* out = nullptr; + ThrowOnError(GetApi().KernelContext_GetLogger(this->ctx_, &out)); + return Logger{out}; } -inline void Value::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const { - ThrowOnError(GetApi().GetStringTensorElement(p_, buffer_length, element_index, buffer)); +inline void KernelContext::ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const { + ThrowOnError(GetApi().KernelContext_ParallelFor(ctx_, fn, total, num_batch, usr_data)); } -inline void Value::FillStringTensor(const char* const* s, size_t s_len) { - ThrowOnError(GetApi().FillStringTensor(p_, s, s_len)); +inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) { + Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_)); } -inline void Value::FillStringTensorElement(const char* s, size_t index) { - ThrowOnError(GetApi().FillStringTensorElement(p_, s, index)); +namespace detail { +template +inline 
KernelInfo KernelInfoImpl::Copy() const { + OrtKernelInfo* info_copy = nullptr; + Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy)); + return KernelInfo{info_copy}; } template -T* Value::GetTensorMutableData() { - T* out; - ThrowOnError(GetApi().GetTensorMutableData(p_, (void**)&out)); +inline size_t KernelInfoImpl::GetInputCount() const { + size_t out = 0; + ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out)); return out; } template -const T* Value::GetTensorData() const { - T* out; - ThrowOnError(GetApi().GetTensorMutableData(p_, (void**)&out)); +inline size_t KernelInfoImpl::GetOutputCount() const { + size_t out = 0; + ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out)); return out; } -#if !defined(DISABLE_SPARSE_TENSORS) template -inline const T* Value::GetSparseTensorValues() const { - const void* out; - ThrowOnError(GetApi().GetSparseTensorValues(p_, &out)); - return reinterpret_cast(out); +inline std::string KernelInfoImpl::GetInputName(size_t index) const { + size_t size = 0; + + // Feed nullptr for the data buffer to query the true size of the string value + Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size)); + + std::string out; + out.resize(size); + Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size)); + out.resize(size - 1); // remove the terminating character '\0' + + return out; } -#endif // !defined(DISABLE_SPARSE_TENSORS) template -inline T& Value::At(const std::vector& location) { - static_assert(!std::is_same::value, "this api does not support std::string"); - T* out; - ThrowOnError(GetApi().TensorAt(p_, location.data(), location.size(), (void**)&out)); - return *out; -} +inline std::string KernelInfoImpl::GetOutputName(size_t index) const { + size_t size = 0; -inline TypeInfo Value::GetTypeInfo() const { - OrtTypeInfo* output; - ThrowOnError(GetApi().GetTypeInfo(p_, &output)); - return TypeInfo{output}; + // Feed nullptr for the data buffer 
to query the true size of the string value + Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size)); + + std::string out; + out.resize(size); + Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size)); + out.resize(size - 1); // remove the terminating character '\0' + + return out; } -inline TensorTypeAndShapeInfo Value::GetTensorTypeAndShapeInfo() const { - OrtTensorTypeAndShapeInfo* output; - ThrowOnError(GetApi().GetTensorTypeAndShape(p_, &output)); - return TensorTypeAndShapeInfo{output}; +template +inline TypeInfo KernelInfoImpl::GetInputTypeInfo(size_t index) const { + OrtTypeInfo* out = nullptr; + ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out)); + return TypeInfo{out}; } -// -// Custom OP API Inlines -// -inline void CustomOpApi::ThrowOnError(OrtStatus* status) { - Ort::ThrowOnError(api_, status); +template +inline TypeInfo KernelInfoImpl::GetOutputTypeInfo(size_t index) const { + OrtTypeInfo* out = nullptr; + ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out)); + return TypeInfo{out}; } -template <> -inline float CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { - float out; - ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out)); - return out; +template +inline Value KernelInfoImpl::GetTensorAttribute(const char* name, OrtAllocator* allocator) const { + OrtValue* out = nullptr; + ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out)); + return Value{out}; } -template <> -inline int64_t CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { - int64_t out; - ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out)); - return out; +template +inline ConstValue KernelInfoImpl::GetTensorConstantInput(size_t index, int* is_constant) const { + const OrtValue* out = nullptr; + 
ThrowOnError(GetApi().KernelInfoGetConstantInput_tensor(this->p_, index, is_constant, &out)); + return ConstValue{out}; } -template <> -inline std::string CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { +template +inline std::string KernelInfoImpl::GetNodeName() const { size_t size = 0; - std::string out; - // Feed nullptr for the data buffer to query the true size of the string attribute - OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size); + // Feed nullptr for the data buffer to query the true size of the string value + Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, nullptr, &size)); + + std::string out; + out.resize(size); + Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, &out[0], &size)); + out.resize(size - 1); // remove the terminating character '\0' - if (status == nullptr) { - out.resize(size); - ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size)); - out.resize(size - 1); // remove the terminating character '\0' - } else { - ThrowOnError(status); - } return out; } -template <> -inline std::vector CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { +template +inline Logger KernelInfoImpl::GetLogger() const { + const OrtLogger* out = nullptr; + ThrowOnError(GetApi().KernelInfo_GetLogger(this->p_, &out)); + return Logger{out}; +} + +inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) { + Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out)); +} + +inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) { + Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out)); +} + +inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) { size_t size = 0; - std::vector out; + // Feed nullptr for the data buffer to query the true size of the string attribute + 
Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size)); + std::string out; + out.resize(size); + Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size)); + out.resize(size - 1); // remove the terminating character '\0' + out.swap(result); +} + +inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector& result) { + size_t size = 0; // Feed nullptr for the data buffer to query the true size of the attribute - OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size); + Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size)); - if (status == nullptr) { - out.resize(size); - ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size)); - } else { - ThrowOnError(status); - } - return out; + std::vector out; + out.resize(size); + Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size)); + out.swap(result); } -template <> -inline std::vector CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { +inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector& result) { size_t size = 0; - std::vector out; // Feed nullptr for the data buffer to query the true size of the attribute - OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size); + Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size)); - if (status == nullptr) { - out.resize(size); - ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size)); - } else { - ThrowOnError(status); - } - return out; -} -inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) { - OrtTensorTypeAndShapeInfo* out; - ThrowOnError(api_.GetTensorTypeAndShape(value, &out)); - return out; + std::vector out; + out.resize(size); + 
Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size)); + out.swap(result); +} +} // namespace detail + +inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl{info} {} + +inline Op::Op(OrtOp* p) : Base(p) {} + +inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version, + const char** type_constraint_names, + const ONNXTensorElementDataType* type_constraint_values, + size_t type_constraint_count, + const OpAttr* attr_values, size_t attr_count, + size_t input_count, size_t output_count) { + static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*), + "OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely"); + auto attr_input_values = reinterpret_cast(attr_values); + OrtOp* op; + Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values, + static_cast(type_constraint_count), + attr_input_values, + static_cast(attr_count), + static_cast(input_count), + static_cast(output_count), &op)); + return Op{op}; +} + +inline void Op::Invoke(const OrtKernelContext* context, + const Value* input_values, + size_t input_count, + Value* output_values, + size_t output_count) { + static_assert(sizeof(Value) == sizeof(OrtValue*), + "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely"); + auto ort_input_values = reinterpret_cast(input_values); + auto ort_output_values = reinterpret_cast(output_values); + Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast(input_count), + ort_output_values, static_cast(output_count))); } -inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) { - size_t out; - ThrowOnError(api_.GetTensorShapeElementCount(info, &out)); - return out; +inline void Op::Invoke(const OrtKernelContext* context, + const OrtValue* const* input_values, + size_t input_count, + OrtValue* const* 
output_values, + size_t output_count) { + Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast(input_count), + output_values, static_cast(output_count))); } -inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) { - ONNXTensorElementDataType out; - ThrowOnError(api_.GetTensorElementType(info, &out)); - return out; +inline std::string GetVersionString() { + return OrtGetApiBase()->GetVersionString(); } -inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) { - size_t out; - ThrowOnError(api_.GetDimensionsCount(info, &out)); - return out; +inline std::string GetBuildInfoString() { + return GetApi().GetBuildInfoString(); } -inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) { - ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length)); -} +inline std::vector GetAvailableProviders() { + char** providers; + int len; -inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) { - ThrowOnError(api_.SetDimensions(info, dim_values, dim_count)); -} + auto release_fn = [&len](char** providers) { + // This should always return nullptr. 
+ ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len)); + }; -template -inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) { - T* data; - ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast(&data))); - return data; + ThrowOnError(GetApi().GetAvailableProviders(&providers, &len)); + std::unique_ptr guard(providers, release_fn); + std::vector available_providers; + available_providers.reserve(static_cast(len)); + for (int i = 0; i < len; ++i) { + available_providers.emplace_back(providers[i]); + } + return available_providers; } -inline const OrtMemoryInfo* CustomOpApi::GetTensorMemoryInfo(_In_ const OrtValue* value) { - const OrtMemoryInfo* mem_info; - ThrowOnError(api_.GetTensorMemoryInfo(value, &mem_info)); - return mem_info; -} +template +void CustomOpBase::GetSessionConfigs(std::unordered_map& out, + ConstSessionOptions options) const { + const TOp* derived = static_cast(this); + std::vector keys = derived->GetSessionConfigKeys(); -template -inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) { - return GetTensorMutableData(const_cast(value)); -} + out.reserve(keys.size()); -inline std::vector CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) { - std::vector output(GetDimensionsCount(info)); - GetDimensions(info, output.data(), output.size()); - return output; -} + std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), ""); + const size_t prefix_size = config_entry_key.length(); -inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) { - api_.ReleaseTensorTypeAndShapeInfo(input); + for (const auto& key : keys) { + config_entry_key.resize(prefix_size); + config_entry_key.append(key); + out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), ""); + } } -inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) { - size_t out; - 
ThrowOnError(api_.KernelContext_GetInputCount(context, &out)); - return out; +inline ShapeInferContext::ShapeInferContext(const OrtApi* ort_api, + OrtShapeInferContext* ctx) : ort_api_(ort_api), ctx_(ctx) { + size_t input_count = 0; + Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputCount(ctx_, &input_count)); + for (size_t ith_input = 0; ith_input < input_count; ++ith_input) { + OrtTensorTypeAndShapeInfo* info{}; + Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputTypeShape(ctx, ith_input, &info)); + TensorTypeAndShapeInfo type_shape_info(info); + auto integer_shape = type_shape_info.GetShape(); + std::vector symbolic_shape(integer_shape.size(), {}); + if (!integer_shape.empty()) { + type_shape_info.GetSymbolicDimensions(&symbolic_shape[0], integer_shape.size()); + } + Shape shape; + for (size_t ith = 0; ith < integer_shape.size(); ++ith) { + if (symbolic_shape[ith] && std::string{symbolic_shape[ith]}.size() > 0) { + shape.emplace_back(symbolic_shape[ith]); + } else { + shape.emplace_back(integer_shape[ith]); + } + } + input_shapes_.push_back(std::move(shape)); + type_shape_info.release(); + } } -inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) { - const OrtValue* out; - ThrowOnError(api_.KernelContext_GetInput(context, index, &out)); - return out; -} +inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type) { + OrtTensorTypeAndShapeInfo* info = {}; + ORT_CXX_RETURN_ON_API_FAIL(ort_api_->CreateTensorTypeAndShapeInfo(&info)); + ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetTensorElementType(info, type)); -inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) { - size_t out; - ThrowOnError(api_.KernelContext_GetOutputCount(context, &out)); - return out; + using InfoPtr = std::unique_ptr>; + + InfoPtr info_ptr(info, [this](OrtTensorTypeAndShapeInfo* obj) { + ort_api_->ReleaseTensorTypeAndShapeInfo(obj); + }); + 
+ std::vector integer_dims; + std::vector symbolic_dims; + + for (const auto dim : shape) { + if (dim.IsInt()) { + integer_dims.push_back(dim.AsInt()); + symbolic_dims.push_back(""); + } else { + if (!dim.AsSym() || std::string{dim.AsSym()}.empty()) { + ORT_CXX_API_THROW("Symbolic dim must not be an empty string", ORT_INVALID_ARGUMENT); + } + integer_dims.push_back(SymbolicInteger::INVALID_INT_DIM); + symbolic_dims.push_back(dim.AsSym()); + } + } + + ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetDimensions(info, integer_dims.data(), integer_dims.size())); + ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetSymbolicDimensions(info, symbolic_dims.data(), symbolic_dims.size())); + ORT_CXX_RETURN_ON_API_FAIL(ort_api_->ShapeInferContext_SetOutputTypeShape(ctx_, indice, info)); + return Status{nullptr}; +} + +inline int64_t ShapeInferContext::GetAttrInt(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + int64_t i = {}; + size_t out = {}; + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INT, &i, sizeof(i), &out)); + return i; +} + +inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + int64_t i = {}; + size_t out = {}; + // first call to get the bytes needed + // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. + // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). + // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. 
+ auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, &i, sizeof(i), &out); + if (status) { + size_t num_i = out / sizeof(int64_t); + ShapeInferContext::Ints ints(num_i, 0); + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, ints.data(), out, &out)); + return ints; + } else { + if (out == 0u) { + return {}; + } + return {i}; + } } -inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, - _In_ const int64_t* dim_values, size_t dim_count) { - OrtValue* out; - ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out)); - return out; +inline float ShapeInferContext::GetAttrFloat(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + float f = {}; + size_t out = {}; + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOAT, &f, sizeof(f), &out)); + return f; } -inline void* CustomOpApi::KernelContext_GetGPUComputeStream(const OrtKernelContext* context) { - void* out; - ThrowOnError(api_.KernelContext_GetGPUComputeStream(context, &out)); - return out; +inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + float f = {}; + size_t out = {}; + // first call to get the bytes needed + // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. + // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). + // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. 
+ auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, &f, sizeof(f), &out); + if (status) { + size_t num_f = out / sizeof(float); + ShapeInferContext::Floats floats(num_f, 0); + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, floats.data(), out, &out)); + return floats; + } else { + if (out == 0u) { + return {}; + } + return {f}; + } } -inline SessionOptions& SessionOptions::DisablePerSessionThreads() { - ThrowOnError(GetApi().DisablePerSessionThreads(p_)); - return *this; +inline std::string ShapeInferContext::GetAttrString(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + char c = {}; + size_t out = {}; + // first call to get the bytes needed + auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, &c, sizeof(char), &out); + if (status) { + std::vector chars(out, '\0'); + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, chars.data(), out, &out)); + return {chars.data()}; + } else { + return {c}; + } } -inline std::vector GetAvailableProviders() { - int len; - char** providers; - const OrtApi& api = GetApi(); - ThrowOnError(api.GetAvailableProviders(&providers, &len)); - std::vector available_providers(providers, providers + len); - ThrowOnError(api.ReleaseAvailableProviders(providers, len)); - return available_providers; +inline ShapeInferContext::Strings ShapeInferContext::GetAttrStrings(const char* attr_name) { + const auto* attr = GetAttrHdl(attr_name); + char c = {}; + size_t out = {}; + // first call to get the bytes needed + // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. + // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). + // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. 
In this case, status == nullptr and we need to return {i}. + auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, &c, sizeof(char), &out); + if (status) { + std::vector chars(out, '\0'); + Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, chars.data(), out, &out)); + ShapeInferContext::Strings strings; + char* char_st = chars.data(); + char* char_ed = char_st + out; + while (char_st < char_ed) { + strings.emplace_back(char_st); + while (*char_st != '\0') { + char_st++; + } + char_st++; + } + return strings; + } else { + if (out == 0u) { + return {}; + } + return {std::string{c}}; + } } -SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val); +inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) const { + const OrtOpAttr* attr_hdl = {}; + Ort::ThrowOnError(ort_api_->ShapeInferContext_GetAttribute(ctx_, attr_name, &attr_hdl)); + return attr_hdl; +} } // namespace Ort diff --git a/libs/onnxruntime/include/onnxruntime_float16.h b/libs/onnxruntime/include/onnxruntime_float16.h new file mode 100644 index 0000000..408d3cc --- /dev/null +++ b/libs/onnxruntime/include/onnxruntime_float16.h @@ -0,0 +1,535 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +namespace onnxruntime_float16 { + +namespace detail { + +enum class endian { +#if defined(_WIN32) + little = 0, + big = 1, + native = little, +#elif defined(__GNUC__) || defined(__clang__) + little = __ORDER_LITTLE_ENDIAN__, + big = __ORDER_BIG_ENDIAN__, + native = __BYTE_ORDER__, +#else +#error onnxruntime_float16::detail::endian is not implemented in this environment. +#endif +}; + +static_assert( + endian::native == endian::little || endian::native == endian::big, + "Only little-endian or big-endian native byte orders are supported."); + +} // namespace detail + +/// +/// Shared implementation between public and internal classes. CRTP pattern. 
+/// +template +struct Float16Impl { + protected: + /// + /// Converts from float to uint16_t float16 representation + /// + /// + /// + constexpr static uint16_t ToUint16Impl(float v) noexcept; + + /// + /// Converts float16 to float + /// + /// float representation of float16 value + float ToFloatImpl() const noexcept; + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + uint16_t AbsImpl() const noexcept { + return static_cast(val & ~kSignMask); + } + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + uint16_t NegateImpl() const noexcept { + return IsNaN() ? val : static_cast(val ^ kSignMask); + } + + public: + // uint16_t special values + static constexpr uint16_t kSignMask = 0x8000U; + static constexpr uint16_t kBiasedExponentMask = 0x7C00U; + static constexpr uint16_t kPositiveInfinityBits = 0x7C00U; + static constexpr uint16_t kNegativeInfinityBits = 0xFC00U; + static constexpr uint16_t kPositiveQNaNBits = 0x7E00U; + static constexpr uint16_t kNegativeQNaNBits = 0xFE00U; + static constexpr uint16_t kMaxValueBits = 0x7BFFU; // Largest normal number + static constexpr uint16_t kOneBits = 0x3C00U; + static constexpr uint16_t kMinusOneBits = 0xBC00U; + + uint16_t val{0}; + + Float16Impl() = default; + + /// + /// Checks if the value is negative + /// + /// true if negative + bool IsNegative() const noexcept { + return static_cast(val) < 0; + } + + /// + /// Tests if the value is NaN + /// + /// true if NaN + bool IsNaN() const noexcept { + return AbsImpl() > kPositiveInfinityBits; + } + + /// + /// Tests if the value is finite + /// + /// true if finite + bool IsFinite() const noexcept { + return AbsImpl() < kPositiveInfinityBits; + } + + /// + /// Tests if the value represents positive infinity. 
+ /// + /// true if positive infinity + bool IsPositiveInfinity() const noexcept { + return val == kPositiveInfinityBits; + } + + /// + /// Tests if the value represents negative infinity + /// + /// true if negative infinity + bool IsNegativeInfinity() const noexcept { + return val == kNegativeInfinityBits; + } + + /// + /// Tests if the value is either positive or negative infinity. + /// + /// True if absolute value is infinity + bool IsInfinity() const noexcept { + return AbsImpl() == kPositiveInfinityBits; + } + + /// + /// Tests if the value is NaN or zero. Useful for comparisons. + /// + /// True if NaN or zero. + bool IsNaNOrZero() const noexcept { + auto abs = AbsImpl(); + return (abs == 0 || abs > kPositiveInfinityBits); + } + + /// + /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). + /// + /// True if so + bool IsNormal() const noexcept { + auto abs = AbsImpl(); + return (abs < kPositiveInfinityBits) // is finite + && (abs != 0) // is not zero + && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent) + } + + /// + /// Tests if the value is subnormal (denormal). + /// + /// True if so + bool IsSubnormal() const noexcept { + auto abs = AbsImpl(); + return (abs < kPositiveInfinityBits) // is finite + && (abs != 0) // is not zero + && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent) + } + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); } + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); } + + /// + /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check + /// for two values by or'ing the private bits together and stripping the sign. 
They are both zero, + /// and therefore equivalent, if the resulting value is still zero. + /// + /// first value + /// second value + /// True if both arguments represent zero + static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept { + return static_cast((lhs.val | rhs.val) & ~kSignMask) == 0; + } + + bool operator==(const Float16Impl& rhs) const noexcept { + if (IsNaN() || rhs.IsNaN()) { + // IEEE defines that NaN is not equal to anything, including itself. + return false; + } + return val == rhs.val; + } + + bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); } + + bool operator<(const Float16Impl& rhs) const noexcept { + if (IsNaN() || rhs.IsNaN()) { + // IEEE defines that NaN is unordered with respect to everything, including itself. + return false; + } + + const bool left_is_negative = IsNegative(); + if (left_is_negative != rhs.IsNegative()) { + // When the signs of left and right differ, we know that left is less than right if it is + // the negative value. The exception to this is if both values are zero, in which case IEEE + // says they should be equal, even if the signs differ. + return left_is_negative && !AreZero(*this, rhs); + } + return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative); + } +}; + +// The following Float16_t conversions are based on the code from +// Eigen library. + +// The conversion routines are Copyright (c) Fabian Giesen, 2016. +// The original license follows: +// +// Copyright (c) Fabian Giesen, 2016 +// All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted. +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +namespace detail { +union float32_bits { + unsigned int u; + float f; +}; +} // namespace detail + +template +inline constexpr uint16_t Float16Impl::ToUint16Impl(float v) noexcept { + detail::float32_bits f{}; + f.f = v; + + constexpr detail::float32_bits f32infty = {255 << 23}; + constexpr detail::float32_bits f16max = {(127 + 16) << 23}; + constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23}; + constexpr unsigned int sign_mask = 0x80000000u; + uint16_t val = static_cast(0x0u); + + unsigned int sign = f.u & sign_mask; + f.u ^= sign; + + // NOTE all the integer compares in this function can be safely + // compiled into signed compares since all operands are below + // 0x80000000. Important if you want fast straight SSE2 code + // (since there's no unsigned PCMPGTD). + + if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set) + val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf + } else { // (De)normalized number or zero + if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero + // use a magic value to align our 10 mantissa bits at the bottom of + // the float. as long as FP addition is round-to-nearest-even this + // just works. + f.f += denorm_magic.f; + + // and one integer subtract of the bias later, we have our final float! 
+ val = static_cast(f.u - denorm_magic.u); + } else { + unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd + + // update exponent, rounding bias part 1 + // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but + // without arithmetic overflow. + f.u += 0xc8000fffU; + // rounding bias part 2 + f.u += mant_odd; + // take the bits! + val = static_cast(f.u >> 13); + } + } + + val |= static_cast(sign >> 16); + return val; +} + +template +inline float Float16Impl::ToFloatImpl() const noexcept { + constexpr detail::float32_bits magic = {113 << 23}; + constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift + detail::float32_bits o{}; + + o.u = (val & 0x7fff) << 13; // exponent/mantissa bits + unsigned int exp = shifted_exp & o.u; // just the exponent + o.u += (127 - 15) << 23; // exponent adjust + + // handle exponent special cases + if (exp == shifted_exp) { // Inf/NaN? + o.u += (128 - 16) << 23; // extra exp adjust + } else if (exp == 0) { // Zero/Denormal? + o.u += 1 << 23; // extra exp adjust + o.f -= magic.f; // re-normalize + } + + // Attempt to workaround the Internal Compiler Error on ARM64 + // for bitwise | operator, including std::bitset +#if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC) + if (IsNegative()) { + return -o.f; + } +#else + // original code: + o.u |= (val & 0x8000U) << 16U; // sign bit +#endif + return o.f; +} + +/// Shared implementation between public and internal classes. CRTP pattern. +template +struct BFloat16Impl { + protected: + /// + /// Converts from float to uint16_t float16 representation + /// + /// + /// + static uint16_t ToUint16Impl(float v) noexcept; + + /// + /// Converts bfloat16 to float + /// + /// float representation of bfloat16 value + float ToFloatImpl() const noexcept; + + /// + /// Creates an instance that represents absolute value. 
+ /// + /// Absolute value + uint16_t AbsImpl() const noexcept { + return static_cast(val & ~kSignMask); + } + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + uint16_t NegateImpl() const noexcept { + return IsNaN() ? val : static_cast(val ^ kSignMask); + } + + public: + // uint16_t special values + static constexpr uint16_t kSignMask = 0x8000U; + static constexpr uint16_t kBiasedExponentMask = 0x7F80U; + static constexpr uint16_t kPositiveInfinityBits = 0x7F80U; + static constexpr uint16_t kNegativeInfinityBits = 0xFF80U; + static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U; + static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U; + static constexpr uint16_t kMaxValueBits = 0x7F7FU; + static constexpr uint16_t kRoundToNearest = 0x7FFFU; + static constexpr uint16_t kOneBits = 0x3F80U; + static constexpr uint16_t kMinusOneBits = 0xBF80U; + + uint16_t val{0}; + + BFloat16Impl() = default; + + /// + /// Checks if the value is negative + /// + /// true if negative + bool IsNegative() const noexcept { + return static_cast(val) < 0; + } + + /// + /// Tests if the value is NaN + /// + /// true if NaN + bool IsNaN() const noexcept { + return AbsImpl() > kPositiveInfinityBits; + } + + /// + /// Tests if the value is finite + /// + /// true if finite + bool IsFinite() const noexcept { + return AbsImpl() < kPositiveInfinityBits; + } + + /// + /// Tests if the value represents positive infinity. + /// + /// true if positive infinity + bool IsPositiveInfinity() const noexcept { + return val == kPositiveInfinityBits; + } + + /// + /// Tests if the value represents negative infinity + /// + /// true if negative infinity + bool IsNegativeInfinity() const noexcept { + return val == kNegativeInfinityBits; + } + + /// + /// Tests if the value is either positive or negative infinity. 
+ /// + /// True if absolute value is infinity + bool IsInfinity() const noexcept { + return AbsImpl() == kPositiveInfinityBits; + } + + /// + /// Tests if the value is NaN or zero. Useful for comparisons. + /// + /// True if NaN or zero. + bool IsNaNOrZero() const noexcept { + auto abs = AbsImpl(); + return (abs == 0 || abs > kPositiveInfinityBits); + } + + /// + /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). + /// + /// True if so + bool IsNormal() const noexcept { + auto abs = AbsImpl(); + return (abs < kPositiveInfinityBits) // is finite + && (abs != 0) // is not zero + && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent) + } + + /// + /// Tests if the value is subnormal (denormal). + /// + /// True if so + bool IsSubnormal() const noexcept { + auto abs = AbsImpl(); + return (abs < kPositiveInfinityBits) // is finite + && (abs != 0) // is not zero + && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent) + } + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); } + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); } + + /// + /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check + /// for two values by or'ing the private bits together and stripping the sign. They are both zero, + /// and therefore equivalent, if the resulting value is still zero. + /// + /// first value + /// second value + /// True if both arguments represent zero + static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept { + // IEEE defines that positive and negative zero are equal, this gives us a quick equality check + // for two values by or'ing the private bits together and stripping the sign. 
They are both zero, + // and therefore equivalent, if the resulting value is still zero. + return static_cast((lhs.val | rhs.val) & ~kSignMask) == 0; + } +}; + +template +inline uint16_t BFloat16Impl::ToUint16Impl(float v) noexcept { + uint16_t result; + if (std::isnan(v)) { + result = kPositiveQNaNBits; + } else { + auto get_msb_half = [](float fl) { + uint16_t result; +#ifdef __cpp_if_constexpr + if constexpr (detail::endian::native == detail::endian::little) { +#else + if (detail::endian::native == detail::endian::little) { +#endif + std::memcpy(&result, reinterpret_cast(&fl) + sizeof(uint16_t), sizeof(uint16_t)); + } else { + std::memcpy(&result, &fl, sizeof(uint16_t)); + } + return result; + }; + + uint16_t upper_bits = get_msb_half(v); + union { + uint32_t U32; + float F32; + }; + F32 = v; + U32 += (upper_bits & 1) + kRoundToNearest; + result = get_msb_half(F32); + } + return result; +} + +template +inline float BFloat16Impl::ToFloatImpl() const noexcept { + if (IsNaN()) { + return std::numeric_limits::quiet_NaN(); + } + float result; + char* const first = reinterpret_cast(&result); + char* const second = first + sizeof(uint16_t); +#ifdef __cpp_if_constexpr + if constexpr (detail::endian::native == detail::endian::little) { +#else + if (detail::endian::native == detail::endian::little) { +#endif + std::memset(first, 0, sizeof(uint16_t)); + std::memcpy(second, &val, sizeof(uint16_t)); + } else { + std::memcpy(first, &val, sizeof(uint16_t)); + std::memset(second, 0, sizeof(uint16_t)); + } + return result; +} + +} // namespace onnxruntime_float16 diff --git a/libs/onnxruntime/include/onnxruntime_lite_custom_op.h b/libs/onnxruntime/include/onnxruntime_lite_custom_op.h new file mode 100644 index 0000000..ce87d8c --- /dev/null +++ b/libs/onnxruntime/include/onnxruntime_lite_custom_op.h @@ -0,0 +1,1119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +// Summary +// The header has APIs to save custom op authors the trouble of defining schemas, +// which will be inferred by functions' signature, as long as their argument list has types supported here. +// Input could be: +// 1. Tensor of onnx data types. +// 2. Span of onnx data types. +// 3. Scalar of onnx data types. +// A input could be optional if indicated as std::optional<...>. +// For an output, it must be a tensor of onnx data types. +// Further, the header also has utility for a simple custom struct, where resources could be kept, to be registered as a custom op. +// For concrete examples, please search keyword "LiteCustomOpTest" under "/onnxruntime/test/". +// Note - all APIs in this header are ABI. + +#pragma once +#include "onnxruntime_cxx_api.h" +#include +#include +#include +#include + +namespace Ort { +namespace Custom { + +class ArgBase { + public: + ArgBase(OrtKernelContext* ctx, + size_t indice, + bool is_input) : ctx_(ctx), indice_(indice), is_input_(is_input) {} + virtual ~ArgBase() {}; + + protected: + struct KernelContext ctx_; + size_t indice_; + bool is_input_; +}; + +using ArgPtr = std::unique_ptr; +using ArgPtrs = std::vector; + +class TensorBase : public ArgBase { + public: + TensorBase(OrtKernelContext* ctx, + size_t indice, + bool is_input) : ArgBase(ctx, indice, is_input) {} + + operator bool() const { + return shape_.has_value(); + } + + const std::vector& Shape() const { + if (!shape_.has_value()) { + ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return shape_.value(); + } + + ONNXTensorElementDataType Type() const { + return type_; + } + + int64_t NumberOfElement() const { + if (shape_.has_value()) { + return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies()); + } else { + return 0; + } + } + + std::string Shape2Str() const { + if (shape_.has_value()) { + std::string shape_str; + for (const auto& dim : *shape_) { + shape_str.append(std::to_string(dim)); 
+ shape_str.append(", "); + } + return shape_str; + } else { + return "empty"; + } + } + + bool IsCpuTensor() const { + return strcmp("Cpu", mem_type_) == 0; + } + + virtual const void* DataRaw() const = 0; + virtual size_t SizeInBytes() const = 0; + + protected: + std::optional> shape_; + ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + const char* mem_type_ = "Cpu"; +}; + +template +struct Span { + const T* data_ = {}; + size_t size_ = {}; + void Assign(const T* data, size_t size) { + data_ = data; + size_ = size; + } + size_t size() const { return size_; } + T operator[](size_t indice) const { + return data_[indice]; + } + const T* data() const { return data_; } +}; + +template +class Tensor : public TensorBase { + public: + using TT = typename std::remove_reference::type; + Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) { + if (is_input_) { + if (indice >= ctx_.GetInputCount()) { + ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); + } + const_value_ = ctx_.GetInput(indice); + auto type_shape_info = const_value_.GetTensorTypeAndShapeInfo(); + shape_ = type_shape_info.GetShape(); + } + } + const TT* Data() const { + return reinterpret_cast(const_value_.GetTensorRawData()); + } + TT* Allocate(const std::vector& shape) { + shape_ = shape; + if (!data_) { + shape_ = shape; + data_ = ctx_.GetOutput(indice_, shape).template GetTensorMutableData(); + } + return data_; + } + static TT GetT() { return (TT)0; } + const Span& AsSpan() { + if (!shape_.has_value() || shape_->size() != 1) { + ORT_CXX_API_THROW("invalid shape while trying to get a span out of Ort::Custom::Tensor", + OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + span_.Assign(Data(), static_cast((*shape_)[0])); + return span_; + } + const T& AsScalar() { + if (!shape_.has_value() || shape_->size() != 1 || (*shape_)[0] != 1) { + ORT_CXX_API_THROW("invalid shape while trying to get a scalar 
from Ort::Custom::Tensor", + OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return *Data(); + } + const void* DataRaw() const override { + return reinterpret_cast(Data()); + } + + size_t SizeInBytes() const override { + return sizeof(TT) * static_cast(NumberOfElement()); + } + + private: + ConstValue const_value_; // for input + TT* data_{}; // for output + Span span_; +}; + +template <> +class Tensor : public TensorBase { + public: + using strings = std::vector; + + Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) { + if (is_input_) { + if (indice >= ctx_.GetInputCount()) { + ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); + } + auto const_value = ctx_.GetInput(indice); + auto type_shape_info = const_value.GetTensorTypeAndShapeInfo(); + shape_ = type_shape_info.GetShape(); + auto num_chars = const_value.GetStringTensorDataLength(); + // note - there will be copy ... + auto num_strings = static_cast(NumberOfElement()); + if (num_strings) { + std::vector chars(num_chars + 1, '\0'); + std::vector offsets(num_strings); + const_value.GetStringTensorContent(static_cast(chars.data()), num_chars, offsets.data(), offsets.size()); + auto upper_bound = num_strings - 1; + input_strings_.resize(num_strings); + for (size_t i = upper_bound;; --i) { + if (i < upper_bound) { + chars[offsets[i + 1]] = '\0'; + } + input_strings_[i] = chars.data() + offsets[i]; + if (0 == i) { + break; + } + } + } + } + } + const strings& Data() const { + return input_strings_; + } + const void* DataRaw() const override { + if (input_strings_.size() != 1) { + ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return reinterpret_cast(input_strings_[0].c_str()); + } + size_t SizeInBytes() const override { + if (input_strings_.size() != 1) { + ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return input_strings_[0].size(); + } 
+ void SetStringOutput(const strings& ss, const std::vector& dims) { + shape_ = dims; + std::vector raw; + for (const auto& s : ss) { + raw.push_back(s.data()); + } + auto output = ctx_.GetOutput(indice_, dims.data(), dims.size()); + // note - there will be copy ... + output.FillStringTensor(raw.data(), raw.size()); + } + const Span& AsSpan() { + ORT_CXX_API_THROW("span for TensorT of string not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + const std::string& AsScalar() { + if (input_strings_.size() != 1) { + ORT_CXX_API_THROW("invalid shape while trying to get a scalar string from Ort::Custom::Tensor", + OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return input_strings_[0]; + } + + private: + std::vector input_strings_; // for input +}; + +template <> +class Tensor : public TensorBase { + public: + using strings = std::vector; + using string_views = std::vector; + + Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) { + if (is_input_) { + if (indice >= ctx_.GetInputCount()) { + ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); + } + auto const_value = ctx_.GetInput(indice); + auto type_shape_info = const_value.GetTensorTypeAndShapeInfo(); + shape_ = type_shape_info.GetShape(); + auto num_chars = const_value.GetStringTensorDataLength(); + chars_.resize(num_chars + 1, '\0'); + auto num_strings = static_cast(NumberOfElement()); + if (num_strings) { + std::vector offsets(num_strings); + const_value.GetStringTensorContent(static_cast(chars_.data()), num_chars, offsets.data(), offsets.size()); + offsets.push_back(num_chars); + for (size_t i = 0; i < num_strings; ++i) { + input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]); + } + } + } + } + const string_views& Data() const { + return input_string_views_; + } + const void* DataRaw() const override { + if (input_string_views_.size() != 1) { + ORT_CXX_API_THROW("DataRaw() only applies 
to string scalar", ORT_RUNTIME_EXCEPTION); + } + return reinterpret_cast(input_string_views_[0].data()); + } + size_t SizeInBytes() const override { + if (input_string_views_.size() != 1) { + ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return input_string_views_[0].size(); + } + void SetStringOutput(const strings& ss, const std::vector& dims) { + shape_ = dims; + std::vector raw; + for (const auto& s : ss) { + raw.push_back(s.data()); + } + auto output = ctx_.GetOutput(indice_, dims.data(), dims.size()); + // note - there will be copy ... + output.FillStringTensor(raw.data(), raw.size()); + } + const Span& AsSpan() { + ORT_CXX_API_THROW("span for TensorT of string view not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + std::string_view AsScalar() { + if (input_string_views_.size() != 1) { + ORT_CXX_API_THROW("invalid shape while trying to get a scalar string view from Ort::Custom::Tensor", + OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return input_string_views_[0]; + } + + private: + std::vector chars_; // for input + std::vector input_string_views_; // for input +}; + +using TensorPtr = std::unique_ptr; +using TensorPtrs = std::vector; + +struct TensorArray : public ArgBase { + TensorArray(OrtKernelContext* ctx, + size_t start_indice, + bool is_input) : ArgBase(ctx, + start_indice, + is_input) { + if (is_input) { + auto input_count = ctx_.GetInputCount(); + for (size_t ith_input = start_indice; ith_input < input_count; ++ith_input) { + auto const_value = ctx_.GetInput(start_indice); + auto type_shape_info = const_value.GetTensorTypeAndShapeInfo(); + auto type = type_shape_info.GetElementType(); + TensorPtr tensor; + switch (type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + tensor = 
std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: + tensor = std::make_unique>(ctx, ith_input, true); + break; + default: + ORT_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION); + break; + } + tensors_.emplace_back(tensor.release()); + } // for + } + } + template + T* AllocateOutput(size_t ith_output, const std::vector& shape) { + // ith_output is the indice of output relative to the tensor array + // indice_ + ith_output is the indice relative to context + auto tensor = std::make_unique>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false); + auto raw_output = tensor.get()->Allocate(shape); + tensors_.emplace_back(tensor.release()); + return raw_output; + } + Tensor& AllocateStringTensor(size_t ith_output) { + // ith_output is the indice of output relative to the tensor array + // indice_ + ith_output is the indice relative to context + auto tensor = std::make_unique>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false); + Tensor& output = *tensor; + tensors_.emplace_back(tensor.release()); + return output; + } + size_t Size() const { + return tensors_.size(); + } + const 
TensorPtr& operator[](size_t ith_input) const { + // ith_input is the indice of output relative to the tensor array + return tensors_.at(ith_input); + } + + private: + TensorPtrs tensors_; +}; + +using Variadic = TensorArray; + +/* +Note: +OrtLiteCustomOp inherits from OrtCustomOp to bridge tween a custom func/struct and ort core. +The lifetime of an OrtLiteCustomOp instance is managed by customer code, not ort, so: +1. DO NOT cast OrtLiteCustomOp to OrtCustomOp and release since there is no virtual destructor in the hierarchy. +2. OrtLiteCustomFunc and OrtLiteCustomStruct, as two sub-structs, can be released in form of OrtLiteCustomOp since all members are kept in the OrtLiteCustomOp, + hence memory could still be recycled properly. +Further, OrtCustomOp is a c struct bearing no v-table, so offspring structs are by design to be of zero virtual functions to maintain cast safety. +*/ +struct OrtLiteCustomOp : public OrtCustomOp { + using ConstOptionalFloatTensor = std::optional&>; + using OptionalFloatTensor = std::optional>; + + // CreateTuple + template + static typename std::enable_if>::type + CreateTuple(OrtKernelContext*, ArgPtrs&, size_t, size_t, const std::string&) { + return std::make_tuple(); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + std::tuple current = std::tuple{context}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + std::tuple current = std::tuple{*context}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + +#ifdef ORT_CUDA_CTX + template + static typename std::enable_if::value, 
std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + thread_local CudaContext cuda_context; + cuda_context.Init(*context); + std::tuple current = std::tuple{cuda_context}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } +#endif + +#ifdef ORT_ROCM_CTX + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + thread_local RocmContext rocm_context; + rocm_context.Init(*context); + std::tuple current = std::tuple{rocm_context}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } +#endif + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_input, true)); + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_input, true)); + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_output, false)); + std::tuple current = 
std::tuple{reinterpret_cast(args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_output, false)); + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + +#define CREATE_TUPLE_INPUT(data_type) \ + template \ + static typename std::enable_if*>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if&>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if*>>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_input < num_input) { \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + 
return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } \ + template \ + static typename std::enable_if*>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{&reinterpret_cast*>(args.back().get())->AsSpan()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if&>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())->AsSpan()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if*>>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_input < num_input) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{&reinterpret_cast*>(args.back().get())->AsSpan()}; \ + auto next 
= CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } \ + template \ + static typename std::enable_if::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())->AsScalar()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_input < num_input) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())->AsScalar()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } +#define CREATE_TUPLE_OUTPUT(data_type) \ + template \ + static typename std::enable_if*>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + 
args.push_back(std::make_unique>(context, ith_output, false)); \ + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if&>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_output, false)); \ + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if*>>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_output < num_output) { \ + args.push_back(std::make_unique>(context, ith_output, false)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } +#define CREATE_TUPLE(data_type) \ + CREATE_TUPLE_INPUT(data_type) \ + CREATE_TUPLE_OUTPUT(data_type) + + CREATE_TUPLE(bool) + CREATE_TUPLE(float) + CREATE_TUPLE(Ort::Float16_t) + CREATE_TUPLE(Ort::BFloat16_t) + CREATE_TUPLE(double) + CREATE_TUPLE(int8_t) + CREATE_TUPLE(int16_t) + CREATE_TUPLE(int32_t) + CREATE_TUPLE(int64_t) + CREATE_TUPLE(uint8_t) + CREATE_TUPLE(uint16_t) + CREATE_TUPLE(uint32_t) + CREATE_TUPLE(uint64_t) + CREATE_TUPLE(std::string) + CREATE_TUPLE_INPUT(std::string_view) + CREATE_TUPLE(Ort::Float8E4M3FN_t) + CREATE_TUPLE(Ort::Float8E4M3FNUZ_t) + CREATE_TUPLE(Ort::Float8E5M2_t) + 
CREATE_TUPLE(Ort::Float8E5M2FNUZ_t) + + // ParseArgs ... + template + static typename std::enable_if<0 == sizeof...(Ts)>::type + ParseArgs(std::vector&, std::vector&) { + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + ParseArgs(input_types, output_types); + } + +#ifdef ORT_CUDA_CTX + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + ParseArgs(input_types, output_types); + } +#endif + +#ifdef ORT_ROCM_CTX + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + ParseArgs(input_types, output_types); + } +#endif + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& 
input_types, std::vector& output_types) { + output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + +#define PARSE_INPUT_BASE(pack_type, onnx_type) \ + template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + input_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } \ + template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same>::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + input_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } \ + template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same>::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + input_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } + +#define PARSE_INPUT(data_type, onnx_type) \ + PARSE_INPUT_BASE(const Custom::Tensor*, onnx_type) \ + PARSE_INPUT_BASE(const Custom::Tensor&, onnx_type) \ + PARSE_INPUT_BASE(const Custom::Span*, onnx_type) \ + PARSE_INPUT_BASE(const Custom::Span&, onnx_type) \ + PARSE_INPUT_BASE(data_type, onnx_type) + +#define PARSE_OUTPUT(data_type, onnx_type) \ + template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same*>::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + output_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } \ + template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same&>::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + output_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } \ + template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same*>>::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + 
output_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } + +#define PARSE_ARGS(data_type, onnx_type) \ + PARSE_INPUT(data_type, onnx_type) \ + PARSE_OUTPUT(data_type, onnx_type) + + PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL) + PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) + PARSE_ARGS(Ort::Float16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) + PARSE_ARGS(Ort::BFloat16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16) + PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) + PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) + PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16) + PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) + PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) + PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) + PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16) + PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32) + PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64) + PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) + PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output + PARSE_ARGS(Ort::Float8E4M3FN_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN) + PARSE_ARGS(Ort::Float8E4M3FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ) + PARSE_ARGS(Ort::Float8E5M2_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2) + PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ) + + OrtLiteCustomOp(const char* op_name, + const char* execution_provider, + ShapeInferFn shape_infer_fn, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name), + execution_provider_(execution_provider), + shape_infer_fn_(shape_infer_fn), + start_ver_(start_ver), + end_ver_(end_ver) { + OrtCustomOp::version = ORT_API_VERSION; + + OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast(op)->op_name_.c_str(); }; + OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { 
return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); }; + OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) { return OrtMemTypeDefault; }; + + OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->input_types_.size(); + }; + + OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) { + auto self = reinterpret_cast(op); + return self->input_types_[indice]; + }; + + OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->output_types_.size(); + }; + + OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) { + auto self = reinterpret_cast(op); + return self->output_types_[indice]; + }; + + OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t indice) { + auto self = reinterpret_cast(op); + return self->input_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL; + }; + + OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t indice) { + auto self = reinterpret_cast(op); + return self->output_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? 
INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL; + }; + + OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { + return 1; + }; + + OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { + return 0; + }; + + OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { + return 1; + }; + + OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { + return 0; + }; + + OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; }; + OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { return 0; }; + OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { return 0; }; + OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { return 0; }; + + OrtCustomOp::CreateKernelV2 = {}; + OrtCustomOp::KernelComputeV2 = {}; + OrtCustomOp::KernelCompute = {}; + + OrtCustomOp::InferOutputShapeFn = {}; + + OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->start_ver_; + }; + + OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->end_ver_; + }; + + OrtCustomOp::GetMayInplace = {}; + OrtCustomOp::ReleaseMayInplace = {}; + OrtCustomOp::GetAliasMap = {}; + OrtCustomOp::ReleaseAliasMap = {}; + } + + const std::string op_name_; + const std::string execution_provider_; + + std::vector input_types_; + std::vector output_types_; + + ShapeInferFn shape_infer_fn_ = {}; + + int start_ver_ = 1; + int end_ver_ = MAX_CUSTOM_OP_END_VER; + + void* compute_fn_ = {}; + void* compute_fn_return_status_ = {}; +}; + +//////////////////////////// OrtLiteCustomFunc //////////////////////////////// +// The struct is to implement function-as-op. +// E.g. a function might be defined as: +// void Filter(const Ort::Custom::Tensor& floats_in, Ort::Custom::Tensor& floats_out) { ... 
} +// It could be registered this way: +// Ort::CustomOpDomain v2_domain{"v2"}; +// std::unique_ptr fil_op_ptr{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)}; +// v2_domain.Add(fil_op_ptr.get()); +// session_options.Add(v2_domain); +// For the complete example, please search keyword "LiteCustomOpTest" under "/onnxruntime/test/". +template +struct OrtLiteCustomFunc : public OrtLiteCustomOp { + using ComputeFn = void (*)(Args...); + using ComputeFnReturnStatus = Status (*)(Args...); + using MyType = OrtLiteCustomFunc; + + struct Kernel { + size_t num_input_{}; + size_t num_output_{}; + ComputeFn compute_fn_{}; + ComputeFnReturnStatus compute_fn_return_status_{}; + std::string ep_{}; + }; + + OrtLiteCustomFunc(const char* op_name, + const char* execution_provider, + ComputeFn compute_fn, + ShapeInferFn shape_infer_fn = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) { + compute_fn_ = reinterpret_cast(compute_fn); + ParseArgs(input_types_, output_types_); + + OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { + auto kernel = reinterpret_cast(op_kernel); + std::vector args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); + std::apply([kernel](Args const&... 
t_args) { kernel->compute_fn_(t_args...); }, t); + }; + + OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { + auto kernel = std::make_unique(); + auto me = static_cast(this_); + kernel->compute_fn_ = reinterpret_cast(me->compute_fn_); + Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); + Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); + auto self = static_cast(this_); + kernel->ep_ = self->execution_provider_; + return reinterpret_cast(kernel.release()); + }; + + OrtCustomOp::KernelDestroy = [](void* op_kernel) { + delete reinterpret_cast(op_kernel); + }; + + if (shape_infer_fn_) { + OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { + auto shape_info_fn = static_cast(op)->shape_infer_fn_; + ShapeInferContext ctx(&GetApi(), ort_ctx); + return shape_info_fn(ctx); + }; + } + } + + OrtLiteCustomFunc(const char* op_name, + const char* execution_provider, + ComputeFnReturnStatus compute_fn_return_status, + ShapeInferFn shape_infer_fn = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) { + compute_fn_return_status_ = reinterpret_cast(compute_fn_return_status); + ParseArgs(input_types_, output_types_); + + OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { + auto kernel = reinterpret_cast(op_kernel); + std::vector args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); + return std::apply([kernel](Args const&... 
t_args) { Status status = kernel->compute_fn_return_status_(t_args...); return status.release(); }, t); + }; + + OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { + auto kernel = std::make_unique(); + auto me = static_cast(this_); + kernel->compute_fn_return_status_ = reinterpret_cast(me->compute_fn_return_status_); + Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); + Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); + auto self = static_cast(this_); + kernel->ep_ = self->execution_provider_; + return reinterpret_cast(kernel.release()); + }; + + OrtCustomOp::KernelDestroy = [](void* op_kernel) { + delete reinterpret_cast(op_kernel); + }; + + if (shape_infer_fn_) { + OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { + auto shape_info_fn = static_cast(op)->shape_infer_fn_; + ShapeInferContext ctx(&GetApi(), ort_ctx); + return shape_info_fn(ctx); + }; + } + } +}; // struct OrtLiteCustomFunc + +/////////////////////////// OrtLiteCustomStruct /////////////////////////// +// The struct is to implement struct-as-op. +// E.g. a struct might be defined as: +// struct Merge { +// Merge(const OrtApi* ort_api, const OrtKernelInfo* info) {...} +// void Compute(const Ort::Custom::Tensor& strings_in, +// std::string_view string_in, +// Ort::Custom::Tensor* strings_out) {...} +// bool reverse_ = false; +// }; +// It could be registered this way: +// Ort::CustomOpDomain v2_domain{"v2"}; +// std::unique_ptr mrg_op_ptr{Ort::Custom::CreateLiteCustomOp("Merge", "CPUExecutionProvider")}; +// v2_domain.Add(mrg_op_ptr.get()); +// session_options.Add(v2_domain); +// For the complete example, please search keyword "LiteCustomOpTest" under "/onnxruntime/test/". 
+template +struct OrtLiteCustomStruct : public OrtLiteCustomOp { + template + using CustomComputeFn = void (CustomOp::*)(Args...); + + template + using CustomComputeFnReturnStatus = Status (CustomOp::*)(Args...); + + using MyType = OrtLiteCustomStruct; + + struct Kernel { + size_t num_input_{}; + size_t num_output_{}; + std::unique_ptr custom_op_; + std::string ep_{}; + }; + + OrtLiteCustomStruct(const char* op_name, + const char* execution_provider, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) { + SetCompute(&CustomOp::Compute); + + OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { + auto kernel = std::make_unique(); + Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); + Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); + kernel->custom_op_ = std::make_unique(ort_api, info); + auto self = static_cast(this_); + kernel->ep_ = self->execution_provider_; + return reinterpret_cast(kernel.release()); + }; + + OrtCustomOp::KernelDestroy = [](void* op_kernel) { + delete reinterpret_cast(op_kernel); + }; + + SetShapeInfer(0); + } + + template + void SetCompute(CustomComputeFn) { + ParseArgs(input_types_, output_types_); + OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { + auto kernel = reinterpret_cast(op_kernel); + ArgPtrs args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); + std::apply([kernel](Args const&... 
t_args) { kernel->custom_op_->Compute(t_args...); }, t); + }; + } + + template + void SetCompute(CustomComputeFnReturnStatus) { + ParseArgs(input_types_, output_types_); + OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { + auto kernel = reinterpret_cast(op_kernel); + ArgPtrs args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); + return std::apply([kernel](Args const&... t_args) { Status status = kernel->custom_op_->Compute(t_args...); return status.release(); }, t); + }; + } + + template + decltype(&C::InferOutputShape) SetShapeInfer(decltype(&C::InferOutputShape)) { + OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { + ShapeInferContext ctx(&GetApi(), ort_ctx); + return C::InferOutputShape(ctx); + }; + return {}; + } + + template + void SetShapeInfer(...) { + OrtCustomOp::InferOutputShapeFn = {}; + } +}; // struct OrtLiteCustomStruct + +/////////////////////////// CreateLiteCustomOp //////////////////////////// + +template +OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, + const char* execution_provider, + void (*custom_compute_fn)(Args...), + Status (*shape_infer_fn)(ShapeInferContext&) = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { + using LiteOp = OrtLiteCustomFunc; + return std::make_unique(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release(); +} + +template +OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, + const char* execution_provider, + Status (*custom_compute_fn_v2)(Args...), + Status (*shape_infer_fn)(ShapeInferContext&) = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { + using LiteOp = OrtLiteCustomFunc; + return std::make_unique(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release(); +} + +template +OrtLiteCustomOp* CreateLiteCustomOp(const char* 
op_name, + const char* execution_provider, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { + using LiteOp = OrtLiteCustomStruct; + return std::make_unique(op_name, execution_provider, start_ver, end_ver).release(); +} + +} // namespace Custom +} // namespace Ort diff --git a/libs/onnxruntime/include/onnxruntime_run_options_config_keys.h b/libs/onnxruntime/include/onnxruntime_run_options_config_keys.h index 7ae8480..31c6126 100644 --- a/libs/onnxruntime/include/onnxruntime_run_options_config_keys.h +++ b/libs/onnxruntime/include/onnxruntime_run_options_config_keys.h @@ -25,3 +25,27 @@ // Example usage: "cpu:0;gpu:0" (or) "gpu:0" // By default, the value for this key is empty (i.e.) no memory arenas are shrunk static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memory.enable_memory_arena_shrinkage"; + +// Set to '1' to not synchronize execution providers with CPU at the end of session run. +// By default it will be set to '0' +// Taking CUDA EP as an example, it omits triggering cudaStreamSynchronize on the compute stream. +static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers"; + +// Set HTP performance mode for QNN HTP backend before session run. +// options for HTP performance mode: "burst", "balanced", "default", "high_performance", +// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", +// "sustained_high_performance". Default to "default". +static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode"; + +// Set HTP performance mode for QNN HTP backend post session run. +static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run"; + +// Set RPC control latency for QNN HTP backend +static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency"; + +// Set graph annotation id for CUDA EP. Use with enable_cuda_graph=true.
+// The value should be an integer. If the value is not set, the default value is 0 and +// ORT session only captures one cuda graph before another capture is requested. +// If the value is set to -1, cuda graph capture/replay is disabled in that run. +// Users are not expected to set the value to 0 as it is reserved for internal use. +static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id"; diff --git a/libs/onnxruntime/include/onnxruntime_session_options_config_keys.h b/libs/onnxruntime/include/onnxruntime_session_options_config_keys.h index 948d484..a9216c3 100644 --- a/libs/onnxruntime/include/onnxruntime_session_options_config_keys.h +++ b/libs/onnxruntime/include/onnxruntime_session_options_config_keys.h @@ -44,13 +44,69 @@ static const char* const kOrtSessionOptionsConfigSetDenormalAsZero = "session.se // It controls to run quantization model in QDQ (QuantizelinearDeQuantizelinear) format or not. // "0": enable. ORT does fusion logic for QDQ format. // "1": disable. ORT doesn't do fusion logic for QDQ format. -// Its default value is "0" +// Its default value is "0" unless the DirectML execution provider is registered, in which case it defaults to "1". static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_quant_qdq"; +// It controls whether to enable Double QDQ remover and Identical Children Consolidation +// "0": not to disable. ORT does remove the middle 2 Nodes from Q->(DQ->Q)->DQ pairs +// "1": disable. ORT doesn't remove the middle 2 Nodes from Q->(DQ->Q)->DQ pairs +// Its default value is "0" +static const char* const kOrtSessionOptionsDisableDoubleQDQRemover = "session.disable_double_qdq_remover"; + +// If set to "1", enables the removal of QuantizeLinear/DequantizeLinear node pairs once all QDQ handling has been +// completed. e.g. If after all QDQ handling has completed and we have -> FloatOp -> Q -> DQ -> FloatOp -> the +// Q -> DQ could potentially be removed.
This will provide a performance benefit by avoiding going from float to +// 8-bit and back to float, but could impact accuracy. The impact on accuracy will be model specific and depend on +// other factors like whether the model was created using Quantization Aware Training or Post Training Quantization. +// As such, it's best to test to determine if enabling this works well for your scenario. +// The default value is "0" +// Available since version 1.11. +static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enable_quant_qdq_cleanup"; + // Enable or disable gelu approximation in graph optimization. "0": disable; "1": enable. The default is "0". // GeluApproximation has side effects which may change the inference results. It is disabled by default due to this. static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation"; +// This setting controls whether to enable AheadOfTime function inlining. +// AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model +// as possible with the help of enabled execution providers. +// This can reduce the number of function calls and improve performance because it is done before +// Level1 optimizers and constant folding. However, under some circumstances, when the EPs are not available, +// one can disable the AOT inlining, produce an optimized model and postpone AOT until run time. +// "0": enable; "1": disable. +// Its default value is "0". +static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining"; + +#ifdef ENABLE_TRAINING +// Specifies a path of the file containing a list of memory optimization configurations. +// The value should be a string indicating the file path of the config file. 
+// The content of the config file is a JSON struct like this: +// [ +// "Gelu+Cast+:1:0", +// "Dropout+:1:1" +// ] +// Taking the example of "Gelu+Cast+:1:0", +// > "Gelu+Cast+" is the subgraph string, a valid "subgraph string" should be one subgraph representation +// output by ORT graph transformations. +// > "1" is "optimization strategy", valid values: 0 - disabled, 1 - recompute. +// > "0" is "number of subgraph to apply" which is used to control how many subgraphs to apply optimization, +// to avoid "oversaving" the memory. +static const char* const kOrtSessionOptionsMemoryOptimizerApplyConfig = "optimization.memory_optimizer_config"; + +// Specifies the config for detecting subgraphs for memory footprint reduction. +// The value should be a string containing integers separated using commas. The default value is "0:0". +static const char* const kOrtSessionOptionsMemoryOptimizerProbeConfig = "optimization.enable_memory_probe_recompute_config"; +#endif + +// This setting if set should contain a comma separated list of optimizer names that should be disabled. +// Optimizers may take time to execute and affect model loading time. If you feel that a specific optimizer +// does not provide runtime benefits, but affects your model loading time you may disable it using this config +// entry. This option is not enabled in ORT_MINIMAL_BUILD build. +// A list of optimizers is available in onnxruntime/core/optimizer/graph_transformer_utils.cc +// +// Default is an empty string which means no optimizers are disabled. +static const char* const kOrtSessionOptionsDisableSpecifiedOptimizers = "optimization.disable_specified_optimizers"; + // Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0". // Using device allocators means the memory allocation is made using malloc/new.
static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers"; @@ -69,23 +125,36 @@ static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session // has to guarantee that the model bytes are valid until the ORT session using the model bytes is destroyed. static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "session.use_ort_model_bytes_directly"; -// Save information for replaying graph optimizations later instead of applying them directly. -// -// When an ONNX model is loaded, ORT can perform various optimizations on the graph. -// However, when an ORT format model is loaded, these optimizations are typically not available - this scenario must -// be supported by minimal builds. -// When loading an ONNX model, ORT can optionally save the effects of some optimizations for later replay in an ORT -// format model. These are known as "runtime optimizations" - in an ORT format model, they happen at runtime. -// -// Note: This option is only applicable when loading an ONNX model and saving an ORT format model. -// -// Note: Runtime optimizations are only supported for certain optimizations at the extended level or higher. -// Unsupported optimizations at those levels are not applied at all, while optimizations at other levels are applied -// directly. -// -// "0": disabled, "1": enabled -// The default is "0". -static const char* const kOrtSessionOptionsConfigSaveRuntimeOptimizations = "optimization.save_runtime_optimizations"; +/// +/// Key for using the ORT format model flatbuffer bytes directly for initializers. +/// This avoids copying the bytes and reduces peak memory usage during model loading and initialization. +/// Requires `session.use_ort_model_bytes_directly` to be true. +/// If set, the flatbuffer bytes provided when creating the InferenceSession MUST remain valid for the entire +/// duration of the InferenceSession. 
+/// +static const char* const kOrtSessionOptionsConfigUseORTModelBytesForInitializers = + "session.use_ort_model_bytes_for_initializers"; + +// This should only be specified when exporting an ORT format model for use on a different platform. +// If the ORT format model will be used on ARM platforms set to "1". For other platforms set to "0" +// Available since version 1.11. +static const char* const kOrtSessionOptionsQDQIsInt8Allowed = "session.qdqisint8allowed"; + +// x64 SSE4.1/AVX2/AVX512(with no VNNI) have an overflow problem with quantized matrix multiplication with U8S8. +// To avoid this we need to use slower U8U8 matrix multiplication instead. This option, if +// turned on, uses slower U8U8 matrix multiplications. Only effective with AVX2 or AVX512 +// platforms. +static const char* const kOrtSessionOptionsAvx2PrecisionMode = "session.x64quantprecision"; + +// Specifies how minimal build graph optimizations are handled in a full build. +// These optimizations are at the extended level or higher. +// Possible values and their effects are: +// "save": Save runtime optimizations when saving an ORT format model. +// "apply": Only apply optimizations available in a minimal build. +// ""/<unspecified>: Apply optimizations available in a full build. +// Available since version 1.11. +static const char* const kOrtSessionOptionsConfigMinimalBuildOptimizations = + "optimization.minimal_build_optimizations"; // Note: The options specific to an EP should be specified prior to appending that EP to the session options object in // order for them to take effect. @@ -96,3 +165,127 @@ static const char* const kOrtSessionOptionsConfigSaveRuntimeOptimizations = "opt // If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op // exclusion, set the value to "". static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops"; + +// Enabling dynamic block-sizing for multithreading.
+// With a positive value, thread pool will split a task of N iterations to blocks of size starting from: +// N / (num_of_threads * dynamic_block_base) +// As execution progresses, the size will decrease according to the diminishing residual of N, +// meaning the task will be distributed in smaller granularity for better parallelism. +// For some models, it helps to reduce the variance of E2E inference latency and boost performance. +// The feature will not function by default, specify any positive integer, e.g. "4", to enable it. +// Available since version 1.11. +static const char* const kOrtSessionOptionsConfigDynamicBlockBase = "session.dynamic_block_base"; + +// This option allows to decrease CPU usage between infrequent +// requests and forces any TP threads spinning stop immediately when the last of +// concurrent Run() call returns. +// Spinning is restarted on the next Run() call. +// Applies only to internal thread-pools +static const char* const kOrtSessionOptionsConfigForceSpinningStop = "session.force_spinning_stop"; + +// "1": all inconsistencies encountered during shape and type inference +// will result in failures. +// "0": in some cases warnings will be logged but processing will continue. The default. +// May be useful to expose bugs in models. +static const char* const kOrtSessionOptionsConfigStrictShapeTypeInference = "session.strict_shape_type_inference"; + +// "1": every model using a more recent opset than the latest released one will fail +// "0": the model may or may not work if onnxruntime cannot find an implementation, this option +// is used for development purpose. +static const char* const kOrtSessionOptionsConfigStrictAllowReleasedOpsetsOnly = "session.allow_released_opsets_only"; + +// The file saves configuration for partitioning node among logic streams +static const char* const kNodePartitionConfigFile = "session.node_partition_config_file"; + +// This Option allows setting affinities for intra op threads. 
+// Affinity string follows format: +// logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id +// Semicolon isolates configurations among threads, while comma split processors where ith thread expected to attach to. +// e.g.1,2,3;4,5 +// specifies affinities for two threads, with the 1st thread attach to the 1st, 2nd, and 3rd processor, and 2nd thread to the 4th and 5th. +// To ease the configuration, an "interval" is also allowed: +// e.g. 1-8;8-16;17-24 +// orders that the 1st thread runs on first eight processors, 2nd thread runs on next eight processors, and so forth. +// Note: +// 1. Once set, the number of thread affinities must equal to intra_op_num_threads - 1, since ort does not set affinity on the main thread which +// is started and managed by the calling app; +// 2. For windows, ort will infer the group id from a logical processor id, for example, assuming there are two groups with each has 64 logical processors, +// an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group. +// Hence 64-65 is an invalid configuration, because a windows thread cannot be attached to processors across group boundary. +static const char* const kOrtSessionOptionsConfigIntraOpThreadAffinities = "session.intra_op_thread_affinities"; + +// This option will dump out the model to assist debugging any issues with layout transformation, +// and is primarily intended for developer usage. It is only relevant if an execution provider that requests +// NHWC layout is enabled such as NNAPI, XNNPACK or QNN. +// +// Default is off. Set to "1" to enable. +// +// If modified by layout transformation the model will be dumped after these steps: +// 1) insertion of the layout transformation Transpose nodes +// 2) after those are optimized using the transpose optimizer, +// 3) after the L1 transformers are applied to the updated graph. 
+// The model will be saved to filename post_layout_transform_step_.onnx. +static const char* const kDebugLayoutTransformation = "session.debug_layout_transformation"; + +// Graph nodes that are not supported by the execution providers (EPs) explicitly added to the session are +// assigned (i.e., "fallback") to the CPU EP by default. +// +// This option allows the user to disable the fallback of unsupported graph nodes to the CPU EP. +// If this option is set to "1", session creation will fail if the execution providers other than the CPU EP cannot +// fully support all of the nodes in the graph. +// +// It is invalid to set this option and explicitly add the CPU EP to the session. In this case, session creation +// will also fail with an error. +// +// Option values: +// - "0": CPU EP fallback is not disabled. [DEFAULT] +// - "1": CPU EP fallback is disabled. +static const char* const kOrtSessionOptionsDisableCPUEPFallback = "session.disable_cpu_ep_fallback"; + +// Use this config when serializing a large model after optimization to specify an external initializers file +static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFileName = + "session.optimized_model_external_initializers_file_name"; + +// Use this config to control the minimum size of the initializer when externalizing it during serialization +static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes = + "session.optimized_model_external_initializers_min_size_in_bytes"; + +// Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file. +// The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead. +// "0": disable. (default) +// "1": enable. +static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable"; + +// Specify the file path for the Onnx model which has EP context. 
+// Default to original_file_name_ctx.onnx if not specified +static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path"; + +// Flag to specify whether to dump the EP context into the Onnx model. +// "0": dump the EP context into a separate file, keep the file name in the Onnx model. +// "1": dump the EP context into the Onnx model. (default). +static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; + +// Specify the EPContext node name prefix to make it unique +// in case the user needs to merge/connect multiple EPContext nodes in one model +static const char* const kOrtSessionOptionEpContextNodeNamePrefix = "ep.context_node_name_prefix"; + +// Share EP related resources across EPs +static const char* const kOrtSessionOptionShareEpContexts = "ep.share_ep_contexts"; + +// Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul. +// Option values: +// - "0": Gemm FastMath mode is not enabled. [DEFAULT] +// - "1": Gemm FastMath mode is enabled. +static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; + +// When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option. +// Refer to MatMulNBits op schema for more details. +// If not provided, default is 4. +static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level"; + +// THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME +// Meant to be used with SetEpDynamicOptions +// Specify the type of workload for this session. +// “Default”: OS determines the scheduling priority and processor performance to service this workload. [Default] +// “Efficient”: OS treats this workload as efficiency oriented with low scheduling priority and efficient processor performance.
+static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type"; diff --git a/libs/onnxruntime/include/onnxruntime_training_c_api.h b/libs/onnxruntime/include/onnxruntime_training_c_api.h new file mode 100644 index 0000000..ed6d151 --- /dev/null +++ b/libs/onnxruntime/include/onnxruntime_training_c_api.h @@ -0,0 +1,731 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file contains the training c apis. + +#pragma once +#include +#include "onnxruntime_c_api.h" + +/** \page training_c_cpp_api Training C & C++ APIs + * + * Training C and C++ APIs are an extension of the \ref c_cpp_api "onnxruntime core C and C++ APIs" and should be used in conjunction with them. + * + * In order to train a model with onnxruntime, the following training artifacts must be generated: + * - The training onnx model + * - The checkpoint file + * - The optimizer onnx model + * - The eval onnx model (optional) + * + * These training artifacts can be generated as part of an offline step using the python [utilities](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md) made available in the `onnxruntime-training` python package. + * + * After these artifacts have been generated, the C and C++ utilities listed in this documentation can be leveraged to perform training. + * + * If any problem is encountered, please create an [issue](https://github.com/microsoft/onnxruntime/issues/new) with your scenario and requirements, and we will be sure to respond and follow up on the request. + * + *

Training C API

+ * + * ::OrtTrainingApi - Training C API functions. + * + * This C structure contains functions that enable users to perform training with onnxruntime. + * + * _Sample Code_: + * + * ```c + * #include + * + * OrtApi* g_ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + * OrtTrainingApi* g_ort_training_api = g_ort_api->GetTrainingApi(ORT_API_VERSION); + * + * OrtEnv* env = NULL; + * g_ort_api->CreateEnv(logging_level, logid, &env); + * OrtSessionOptions* session_options = NULL; + * g_ort_api->CreateSessionOptions(&session_options); + * + * OrtCheckpointState* state = NULL; + * g_ort_training_api->LoadCheckpoint(path_to_checkpoint, &state); + * + * OrtTrainingSession* training_session = NULL; + * g_ort_training_api->CreateTrainingSession(env, session_options, training_model_path, + * state, eval_model_path, optimizer_model_path, + * &training_session); + * // Training loop + * { + * g_ort_training_api->TrainStep(...); + * g_ort_training_api->OptimizerStep(...); + * g_ort_training_api->LazyResetGrad(...); + * } + * + * g_ort_training_api->ExportModelForInferencing(training_session, inference_model_path, ...); + * g_ort_training_api->SaveCheckpoint(state, path_to_checkpoint, false); + * + * g_ort_training_api->ReleaseTrainingSession(training_session); + * g_ort_training_api->ReleaseCheckpointState(state); + * ``` + * + * > **Note** + * > The ::OrtCheckpointState contains the entire training state that the ::OrtTrainingSession uses. As a result, the training session must always have access to the state. That is to say, the ::OrtCheckpointState instance must outlive the lifetime of the ::OrtTrainingSession instance. + * + *

Training C++ API

+ * + * @ref TrainingCpp - Training C++ API classes and functions. + * + * These C++ classes and functions enable users to perform training with onnxruntime. + * + * _Sample Code_: + * + * ```cc + * #include <onnxruntime_training_cxx_api.h> + * + * Ort::Env env; + * Ort::SessionOptions session_options; + * + * auto state = Ort::CheckpointState::LoadCheckpoint(path_to_checkpoint); + * auto training_session = Ort::TrainingSession(env, session_options, state, training_model_path, + * eval_model_path, optimizer_model_path); + * + * // Training Loop + * { + * training_session.TrainStep(...); + * training_session.OptimizerStep(...); + * training_session.LazyResetGrad(...); + * } + * + * training_session.ExportModelForInferencing(inference_model_path, ...); + * Ort::CheckpointState::SaveCheckpoint(state, path_to_checkpoint, false); + * ``` + * > **Note** + * > The ::Ort::CheckpointState contains the entire training state that the ::Ort::TrainingSession uses. As a result, the training session must always have access to the state. That is to say, the ::Ort::CheckpointState instance must outlive the lifetime of the ::Ort::TrainingSession instance. + */ + +/** @defgroup TrainingC Ort Training C API + * @{ + */ +ORT_RUNTIME_CLASS(TrainingSession); // Type that enables performing training for the given user models. +ORT_RUNTIME_CLASS(CheckpointState); // Type that holds the training states for the training session. + +/** \brief Type of property to be added to or returned from the ::OrtCheckpointState. + */ +typedef enum OrtPropertyType { + OrtIntProperty = 0, + OrtFloatProperty = 1, + OrtStringProperty = 2, +} OrtPropertyType; + +/** \brief The Training C API that holds onnxruntime training function pointers + * + * All the Training C API functions are defined inside this structure as pointers to functions. + * Call OrtApi::GetTrainingApi to get a pointer to this struct.
+ * + * \nosubgrouping + */ +struct OrtTrainingApi { + /// \name Accessing The Training Session State + /// @{ + + /** \brief Load a checkpoint state from a file on disk into checkpoint_state. + * + * This function will parse a checkpoint file, pull relevant data and load the training + * state into the checkpoint_state. This checkpoint state can then be used to create the + * training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training + * session will resume training from the given checkpoint state. + * \note Note that the training session created with a checkpoint state uses this state to store the entire + * training state (including model parameters, its gradients, the optimizer states and the properties). + * As a result, it is required that the checkpoint state outlive the lifetime of the training session. + * \note Note that the checkpoint file can be either the complete checkpoint or the nominal checkpoint. + * + * \param[in] checkpoint_path Path to the checkpoint file + * \param[out] checkpoint_state Checkpoint state that contains the states of the training session. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, + _Outptr_ OrtCheckpointState** checkpoint_state); + + /** \brief Save the given state to a checkpoint file on disk. + * + * This function serializes the provided checkpoint state to a file on disk. + * This checkpoint can later be loaded by invoking OrtTrainingApi::LoadCheckpoint to resume + * training from this snapshot of the state. + * + * \param[in] checkpoint_state The checkpoint state to save. + * \param[in] checkpoint_path Path to the checkpoint file. + * \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(SaveCheckpoint, _In_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* checkpoint_path, + const bool include_optimizer_state); + + /// @} + + /// \name Implementing The Training Loop + /// @{ + /** \brief Create a training session that can be used to begin or resume training. + * + * This function creates a training session based on the env and session options provided that can + * begin or resume training from a given checkpoint state for the given onnx models. + * The checkpoint state represents the parameters of the training session which will be moved + * to the device specified by the user through the session options (if necessary). + * The training session requires four training artifacts + * - The training onnx model + * - The evaluation onnx model (optional) + * - The optimizer onnx model + * - The checkpoint file + * + * These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md). + * + * \param[in] env Environment to be used for the training session. + * \param[in] options Session options that the user can customize for this training session. + * \param[in] checkpoint_state Training states that the training session uses as a starting point for training. + * \param[in] train_model_path Model to be used to perform training. + * \param[in] eval_model_path Model to be used to perform evaluation. + * \param[in] optimizer_model_path Model to be used to perform gradient descent. + * \param[out] out Created training session. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, + _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, + _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path, + _Outptr_result_maybenull_ OrtTrainingSession** out); + + /** \brief Create a training session that can be used to begin or resume training. + * This api provides a way to load all the training artifacts from buffers instead of files. + * + * \param[in] env Environment to be used for the training session. + * \param[in] options Session options that the user can customize for this training session. + * \param[in] checkpoint_state Training states that the training session uses as a starting point for training. + * \param[in] train_model_data Buffer containing the model data to be used to perform training + * \param[in] train_data_length Length of the buffer containing train_model_data + * \param[in] eval_model_data Buffer containing the model data to be used to perform evaluation + * \param[in] eval_data_length Length of the buffer containing eval_model_data + * \param[in] optim_model_data Buffer containing the model data to be used to perform weight update + * \param[in] optim_data_length Length of the buffer containing optim_model_data + * \param[out] out Created training session. + * + */ + ORT_API2_STATUS(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env, + _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const void* train_model_data, size_t train_data_length, + _In_ const void* eval_model_data, size_t eval_data_length, + _In_ const void* optim_model_data, size_t optim_data_length, + _Outptr_result_maybenull_ OrtTrainingSession** out); + + /// @} + + /// \name Model IO Information + /// @{ + + /** \brief Retrieves the number of user outputs in the training model. 
+ * + * This function returns the number of outputs of the training model so that the user can + * allocate space for the number of outputs when OrtTrainingApi::TrainStep is invoked. + * + * \param[in] sess The `this` pointer to the training session. + * \param[out] out Number of user outputs in the training model. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); + + /** \brief Retrieves the number of user outputs in the eval model. + * + * This function returns the number of outputs of the eval model so that the user can + * allocate space for the number of outputs when OrtTrainingApi::EvalStep is invoked. + * + * \param[in] sess The `this` pointer to the training session. + * \param[out] out Number of user outputs in the eval model. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(TrainingSessionGetEvalModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); + + /** \brief Retrieves the names of user outputs in the training model. + * + * This function returns the names of outputs of the training model that can be associated with the OrtValue(s) + * returned by the OrtTrainingApi::TrainStep function. + * + * \param[in] sess The `this` pointer to the training session. + * \param[in] index Index of the output name requested. + * \param[in] allocator Allocator to use to allocate the memory for the name. + * \param[out] output Name of the training model output at the given index. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output); + + /** \brief Retrieves the names of user outputs in the eval model. 
+ * + * This function returns the names of outputs of the eval model that can be associated with the OrtValue(s) returned + * by the OrtTrainingApi::EvalStep function. + * + * \param[in] sess The `this` pointer to the training session. + * \param[in] index Index of the output name requested. + * \param[in] allocator Allocator to use to allocate the memory for the name. + * \param[out] output Name of the eval model output at the given index. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(TrainingSessionGetEvalModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output); + + /// @} + + /// \name Implementing The Training Loop + /// @{ + + /** \brief Reset the gradients of all trainable parameters to zero lazily. + * + * This function sets the internal state of the training session such that the gradients of the trainable + * parameters in the OrtCheckpointState will be scheduled to be reset just before the new gradients are + * computed on the next invocation of the next OrtTrainingApi::TrainStep. + * + * \param[in] session The `this` pointer to the training session. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(LazyResetGrad, _Inout_ OrtTrainingSession* session); + + /** \brief Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs + * + * This function performs a training step that computes the outputs of the training model and the gradients + * of the trainable parameters for the given inputs. The train step is performed based on the training model + * that was provided to the training session. + * The OrtTrainingApi::TrainStep is equivalent of running forward propagation and backward propagation in a single + * step. + * The gradients computed are stored inside the training session state so they can be later consumed + * by the OrtTrainingApi::OptimizerStep function. 
+ * The gradients can be lazily reset by invoking the OrtTrainingApi::LazyResetGrad function. + * + * \param[in] sess The `this` pointer to the training session. + * \param[in] run_options Run options for this training step. + * \param[in] inputs_len Number of user inputs to the training model. + * \param[in] inputs The user inputs to the training model. + * \param[in] outputs_len Number of user outputs expected from this training step. + * \param[out] outputs User outputs computed by train step. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options, + _In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, + _In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); + + /** \brief Computes the outputs for the eval model for the given inputs + * + * This function performs an eval step that computes the outputs of the eval model for the given inputs. + * The eval step is performed based on the eval model that was provided to the training session. + * + * \param[in] sess The `this` pointer to the training session. + * \param[in] run_options Run options for this eval step. + * \param[in] inputs_len Number of user inputs to the eval model. + * \param[in] inputs The user inputs to the eval model. + * \param[in] outputs_len Number of user outputs expected from this eval step. + * \param[out] outputs User outputs computed by eval step. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(EvalStep, _In_ const OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options, + _In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, + _In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); + + /** \brief Sets the learning rate for this training session. + * + * This function allows users to set the learning rate for the training session. 
The current + * learning rate is maintained by the training session and can be overwritten by invoking + * this function with the desired learning rate. This function should not be used when a valid + * learning rate scheduler is registered. It should be used either to set the learning rate + * derived from a custom learning rate scheduler or to set a constant learning rate to be used + * throughout the training session. + * \note Please note that this function does not set the initial learning rate that may be needed + * by the predefined learning rate schedulers. To set the initial learning rate for learning + * rate schedulers, please look at the function OrtTrainingApi::RegisterLinearLRScheduler. + * + * \param[in] sess The `this` pointer to the training session. + * \param[in] learning_rate Desired learning rate to be set. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(SetLearningRate, _Inout_ OrtTrainingSession* sess, _In_ float learning_rate); + + /** \brief Gets the current learning rate for this training session. + * + * This function allows users to get the learning rate for the training session. The current + * learning rate is maintained by the training session, and users can query it for the purpose + * of implementing their own learning rate schedulers. + * + * \param[in] sess The `this` pointer to the training session. + * \param[out] learning_rate Learning rate currently in use by the training session. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(GetLearningRate, _Inout_ OrtTrainingSession* sess, _Out_ float* learning_rate); + + /** \brief Performs the weight updates for the trainable parameters using the optimizer model. + * + * This function performs the weight update step that updates the trainable parameters such that they + * take a step in the direction of their gradients (gradient descent). 
The optimizer step is performed + * based on the optimizer model that was provided to the training session. + * The updated parameters are stored inside the training state so that they can be used by the next + * OrtTrainingApi::TrainStep function call. + * + * \param[in] sess The `this` pointer to the training session. + * \param[in] run_options Run options for this optimizer step. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess, + _In_opt_ const OrtRunOptions* run_options); + + /** \brief Registers a linear learning rate scheduler for the training session. + * + * Register a linear learning rate scheduler that decays the learning rate by a linearly updated + * multiplicative factor from the initial learning rate set on the training session to 0. The decay + * is performed after the initial warm up phase where the learning rate is linearly incremented + * from 0 to the initial learning rate provided. + * + * \param[in] sess The `this` pointer to the training session. + * \param[in] warmup_step_count Warmup steps for LR warmup. + * \param[in] total_step_count Total step count. + * \param[in] initial_lr The initial learning rate to be used by the training session. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(RegisterLinearLRScheduler, _Inout_ OrtTrainingSession* sess, _In_ const int64_t warmup_step_count, + _In_ const int64_t total_step_count, _In_ const float initial_lr); + + /** \brief Update the learning rate based on the registered learning rate scheduler. + * + * Takes a scheduler step that updates the learning rate that is being used by the training session. + * This function should typically be called before invoking the optimizer step for each round, + * or as determined necessary to update the learning rate being used by the training session.
+ * \note Please note that a valid predefined learning rate scheduler must be first registered to invoke this + * function. + * + * \param[in] sess The `this` pointer to the training session. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(SchedulerStep, _Inout_ OrtTrainingSession* sess); + + /// @} + + /// \name Accessing The Training Session State + /// @{ + /** \brief Retrieves the size of all the parameters. + * + * Calculates the total number of primitive (datatype of the parameters) elements of all the parameters in the + * training state. + * When trainable_only argument is true, the size is calculated for trainable params only. + * + * \param[in] sess The `this` pointer to the training session. + * \param[out] out Size of all parameter elements. + * \param[in] trainable_only Whether to skip non-trainable parameters + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(GetParametersSize, _Inout_ OrtTrainingSession* sess, _Out_ size_t* out, bool trainable_only); + + /** \brief Copy all parameters to a contiguous buffer held by the argument parameters_buffer + * + * The parameters_buffer has to be of the size given by GetParametersSize api call, + * with matching setting for the argument trainable_only. All the target parameters must be of the same + * datatype. The OrtValue must be pre-allocated onto + * the desired device. This is a complementary function to OrtTrainingApi::CopyBufferToParameters. + * Parameter ordering is preserved. + * User is responsible for allocating and freeing the resources used by the parameters_buffer. + * + * \param[in] sess The `this` pointer to the training session. + * \param[in] trainable_only Whether to skip non-trainable parameters + * \param[out] parameters_buffer The pre-allocated OrtValue buffer to copy onto. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(CopyParametersToBuffer, _Inout_ OrtTrainingSession* sess, + _Inout_ OrtValue* parameters_buffer, bool trainable_only); + + /** \brief Copy parameter values from the given contiguous buffer held by parameters_buffer to the training state + * + * The parameters_buffer argument has to be of the size given by OrtTrainingApi::GetParametersSize api call, + * with matching setting for trainable_only argument. All the target parameters must be of the same + * datatype. This is a complementary function to OrtTrainingApi::CopyParametersToBuffer + * and can be used to load updated buffer values onto the training state. + * Parameter ordering is preserved. + * User is responsible for allocating and freeing the resources used by the parameters_buffer. + * In case the training session was created with a nominal checkpoint, invoking this function is required + * to load the updated parameters onto the checkpoint to complete it. + * + * \param[in] sess The `this` pointer to the training session. + * \param[in] trainable_only Whether to skip non-trainable parameters + * \param[in] parameters_buffer The pre-allocated OrtValue buffer to copy from. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(CopyBufferToParameters, _Inout_ OrtTrainingSession* sess, + _Inout_ OrtValue* parameters_buffer, bool trainable_only); + + /// @} + + /// \name Release Training Resources + /// @{ + + /** \brief Frees up the memory used up by the training session. + * + * This function frees up any memory that was allocated in the training session. The training + * session can no longer be used after this call. + * + */ + ORT_CLASS_RELEASE(TrainingSession); + + /** \brief Frees up the memory used up by the checkpoint state. + * + * This function frees up any memory that was allocated in the checkpoint state. The checkpoint + * state can no longer be used after this call.
+ * \note Note that the checkpoint state must be released only after the training session has been released. + * + */ + ORT_CLASS_RELEASE(CheckpointState); + + /// @} + + /// \name Prepare For Inferencing + /// @{ + /** \brief Export a model that can be used for inferencing. + * + * If the training session was provided with an eval model, the training session can generate + * an inference model if it knows the inference graph outputs. The input inference graph outputs + * are used to prune the eval model so that the inference model's outputs align with the provided outputs. + * The exported model is saved at the path provided and can be used for inferencing with InferenceSession. + * \note Note that the function re-loads the eval model from the path provided to OrtTrainingApi::CreateTrainingSession + * and expects that this path still be valid. + * + * \param[in] sess The `this` pointer to the training session. + * \param[in] inference_model_path Path where the inference model should be serialized to. + * \param[in] graph_outputs_len Size of the graph output names array. + * \param[in] graph_output_names Names of the outputs that are needed in the inference model. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(ExportModelForInferencing, _Inout_ OrtTrainingSession* sess, + _In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len, + _In_reads_(graph_outputs_len) const char* const* graph_output_names); + + /// @} + + /// \name Training Utilities + /// @{ + /** \brief Sets the seed used for random number generation in Onnxruntime. + * + * Use this function to generate reproducible results. It should be noted that completely reproducible + * results are not guaranteed. + * + * \param[in] seed The seed to be set. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(SetSeed, _In_ const int64_t seed); + + /// @} + + /// \name Model IO Information + /// @{ + /** \brief Retrieves the number of user inputs in the training model. + * + * This function returns the number of inputs of the training model so that the user can accordingly + * allocate the OrtValue(s) provided to the OrtTrainingApi::TrainStep function. + * + * \param[in] sess The `this` pointer to the training session. + * \param[out] out Number of user inputs in the training model. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(TrainingSessionGetTrainingModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); + + /** \brief Retrieves the number of user inputs in the eval model. + * + * This function returns the number of inputs of the eval model so that the user can accordingly + * allocate the OrtValue(s) provided to the OrtTrainingApi::EvalStep function. + * + * \param[in] sess The `this` pointer to the training session. + * \param[out] out Number of user inputs in the eval model. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(TrainingSessionGetEvalModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); + + /** \brief Retrieves the name of the user input at given index in the training model. + * + * This function returns the names of inputs of the training model that can be associated with the + * OrtValue(s) provided to the OrtTrainingApi::TrainStep function. + * + * \param[in] sess The `this` pointer to the training session. + * \param[in] index The index of the training model input name requested. + * \param[in] allocator The allocator to use to allocate the memory for the requested name. + * \param[out] output Name of the user input for the training model at the given index. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(TrainingSessionGetTrainingModelInputName, _In_ const OrtTrainingSession* sess, size_t index, + _In_ OrtAllocator* allocator, _Outptr_ char** output); + + /** \brief Retrieves the name of the user input at given index in the eval model. + * + * This function returns the names of inputs of the eval model that can be associated with the OrtValue(s) provided + * to the OrtTrainingApi::EvalStep function. + * + * \param[in] sess The `this` pointer to the training session. + * \param[in] index The index of the eval model input name requested. + * \param[in] allocator The allocator to use to allocate the memory for the requested name. + * \param[out] output Name of the user input for the eval model at the given index. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(TrainingSessionGetEvalModelInputName, _In_ const OrtTrainingSession* sess, size_t index, + _In_ OrtAllocator* allocator, _Outptr_ char** output); + + /// @} + + /// \name Accessing The Training Session State + /// @{ + + /** \brief Adds or updates the given property to/in the checkpoint state. + * + * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + * state by the user by calling this function with the corresponding property name and value. + * The given property name must be unique to be able to successfully add the property. + * + * \param[in] checkpoint_state The checkpoint state which should hold the property. + * \param[in] property_name Name of the property being added or updated. + * \param[in] property_type Type of the property associated with the given name. + * \param[in] property_value Property value associated with the given name. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(AddProperty, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* property_name, _In_ enum OrtPropertyType property_type, + _In_ void* property_value); + + /** \brief Gets the property value associated with the given name from the checkpoint state. + * + * Gets the property value from an existing entry in the checkpoint state. The property must + * exist in the checkpoint state to be able to retrieve it successfully. + * + * \param[in] checkpoint_state The checkpoint state that is currently holding the property. + * \param[in] property_name Name of the property being retrieved. + * \param[in] allocator Allocator used to allocate the memory for the property_value. + * \param[out] property_type Type of the property associated with the given name. + * \param[out] property_value Property value associated with the given name. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(GetProperty, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* property_name, _Inout_ OrtAllocator* allocator, + _Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value); + + /// @} + + /// \name Accessing The Training Session State + /// @{ + + /** \brief Load a checkpoint state from a buffer into checkpoint_state. + * + * This function will parse a checkpoint bytes buffer, pull relevant data and load the training + * state into the checkpoint_state. This checkpoint state can then be used to create the + * training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training + * session will resume training from the given checkpoint state. + * \note Note that the training session created with a checkpoint state uses this state to store the entire + * training state (including model parameters, its gradients, the optimizer states and the properties). 
+ * As a result, it is required that the checkpoint state outlive the lifetime of the training session. + * + * \param[in] checkpoint_buffer Pointer to the checkpoint bytes buffer. + * \param[in] num_bytes Number of bytes in the checkpoint buffer. + * \param[out] checkpoint_state Checkpoint state that contains the states of the training session. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer, + _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state); + + /** \brief Retrieves the type and shape information of the parameter associated with the given parameter name. + * + * This function retrieves the type and shape of the parameter associated with the given parameter name. + * The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully. + * + * \param[in] checkpoint_state The checkpoint state. + * \param[in] parameter_name Name of the parameter being retrieved. + * \param[out] parameter_type_and_shape The type and shape of the parameter being retrieved. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape); + + /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + * + * This function updates a model parameter in the checkpoint state with the given parameter data. + * The training session must be already created with the checkpoint state that contains the parameter + * being updated. The given parameter is copied over to the registered device for the training session. + * The parameter must exist in the checkpoint state to be able to update it successfully. + * + * \param[in] checkpoint_state The checkpoint state.
+ * \param[in] parameter_name Name of the parameter being updated. + * \param[in] parameter The parameter data that should replace the existing parameter data. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _In_ OrtValue* parameter); + + /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name. + * + * This function retrieves the model parameter data from the checkpoint state for the given parameter name. + * The parameter is copied over and returned as an OrtValue. The training session must be already created + * with the checkpoint state that contains the parameter being retrieved. + * The parameter must exist in the checkpoint state to be able to retrieve it successfully. + * + * \param[in] checkpoint_state The checkpoint state. + * \param[in] parameter_name Name of the parameter being retrieved. + * \param[in] allocator Allocator used to allocate the memory for the parameter. + * \param[out] parameter The parameter data that is retrieved from the checkpoint state. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter); + + /// @} +}; + +typedef struct OrtTrainingApi OrtTrainingApi; + +/// @} diff --git a/libs/onnxruntime/include/onnxruntime_training_cxx_api.h b/libs/onnxruntime/include/onnxruntime_training_cxx_api.h new file mode 100644 index 0000000..e78c161 --- /dev/null +++ b/libs/onnxruntime/include/onnxruntime_training_cxx_api.h @@ -0,0 +1,418 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "onnxruntime_training_c_api.h" +#include +#include + +namespace Ort::detail { + +#define ORT_DECLARE_TRAINING_RELEASE(NAME) \ + void OrtRelease(Ort##NAME* ptr); + +// These release methods must be forward declared before including onnxruntime_cxx_api.h +// otherwise class Base won't be aware of them +ORT_DECLARE_TRAINING_RELEASE(CheckpointState); +ORT_DECLARE_TRAINING_RELEASE(TrainingSession); + +} // namespace Ort::detail + +#include "onnxruntime_cxx_api.h" + +namespace Ort { + +/// +/// This function returns the C training api struct with the pointers to the ort training C functions. +/// If using C++, please use the class instances instead of invoking the C functions directly. +/// +/// OrtTrainingApi struct with ort training C function pointers. +inline const OrtTrainingApi& GetTrainingApi() { return *GetApi().GetTrainingApi(ORT_API_VERSION); } + +namespace detail { + +#define ORT_DEFINE_TRAINING_RELEASE(NAME) \ + inline void OrtRelease(Ort##NAME* ptr) { GetTrainingApi().Release##NAME(ptr); } + +ORT_DEFINE_TRAINING_RELEASE(CheckpointState); +ORT_DEFINE_TRAINING_RELEASE(TrainingSession); + +#undef ORT_DECLARE_TRAINING_RELEASE +#undef ORT_DEFINE_TRAINING_RELEASE + +} // namespace detail + +using Property = std::variant; + +/** + * \defgroup TrainingCpp Ort Training C++ API + * @{ + */ + +/** \brief Holds the state of the training session. + * + * This class holds the entire training session state that includes model parameters, their gradients, + * optimizer parameters, and user properties. The Ort::TrainingSession leverages the Ort::CheckpointState + * by accessing and updating the contained training state. + * \note Note that the training session created with a checkpoint state uses this state to store the entire + * training state (including model parameters, its gradients, the optimizer states and the properties). 
+ * The Ort::TrainingSession does not hold a copy of the Ort::CheckpointState and as a result, it is required + * that the checkpoint state outlive the lifetime of the training session. + * \note Note that the checkpoint state can be either the complete checkpoint state or the nominal checkpoint + * state depending on the version provided while loading the checkpoint. + * + */ +class CheckpointState : public detail::Base { + private: + CheckpointState(OrtCheckpointState* checkpoint_state) { p_ = checkpoint_state; } + + public: + // Construct the checkpoint state by loading the checkpoint by calling LoadCheckpoint + CheckpointState() = delete; + + /// \name Accessing The Training Session State + /// @{ + + /** \brief Load a checkpoint state from a file on disk into checkpoint_state. + * + * This function will parse a checkpoint file, pull relevant data and load the training + * state and return an instance of Ort::CheckpointState. This checkpoint state can then be used to create the + * training session by instantiating Ort::TrainingSession. By doing so, the training session will resume + * training from the given checkpoint state. + * + * \param[in] path_to_checkpoint Path to the checkpoint file + * \return Ort::CheckpointState object which holds the state of the training session parameters. + * + */ + static CheckpointState LoadCheckpoint(const std::basic_string& path_to_checkpoint); + + /** \brief Load a checkpoint state from a buffer. + * + * This function will parse a checkpoint buffer, pull relevant data and load the training + * state and return an instance of Ort::CheckpointState. This checkpoint state can then be used to create the + * training session by instantiating Ort::TrainingSession. By doing so, the training session will resume + * training from the given checkpoint state. + * + * \param[in] buffer Buffer containing the checkpoint data. + * \return Ort::CheckpointState object which holds the state of the training session parameters. 
+ * + */ + static CheckpointState LoadCheckpointFromBuffer(const std::vector& buffer); + + /** \brief Save the given state to a checkpoint file on disk. + * + * This function serializes the provided checkpoint state to a file on disk. + * This checkpoint can later be loaded by invoking Ort::CheckpointState::LoadCheckpoint to resume + * training from this snapshot of the state. + * + * \param[in] checkpoint_state The checkpoint state to save. + * \param[in] path_to_checkpoint Path to the checkpoint file. + * \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not. + * + */ + static void SaveCheckpoint(const CheckpointState& checkpoint_state, + const std::basic_string& path_to_checkpoint, + const bool include_optimizer_state = false); + + /** \brief Adds or updates the given property to/in the checkpoint state. + * + * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + * state by the user by calling this function with the corresponding property name and value. + * The given property name must be unique to be able to successfully add the property. + * + * \param[in] property_name Name of the property being added or updated. + * \param[in] property_value Property value associated with the given name. + * + */ + void AddProperty(const std::string& property_name, const Property& property_value); + + /** \brief Gets the property value associated with the given name from the checkpoint state. + * + * Gets the property value from an existing entry in the checkpoint state. The property must + * exist in the checkpoint state to be able to retrieve it successfully. + * + * \param[in] property_name Name of the property being retrieved. + * \return Property value associated with the given property name. 
+ * + */ + Property GetProperty(const std::string& property_name); + + /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + * + * This function updates a model parameter in the checkpoint state with the given parameter data. + * The training session must be already created with the checkpoint state that contains the parameter + * being updated. The given parameter is copied over to the registered device for the training session. + * The parameter must exist in the checkpoint state to be able to update it successfully. + * + * \param[in] parameter_name Name of the parameter being updated. + * \param[in] parameter The parameter data that should replace the existing parameter data. + * + */ + void UpdateParameter(const std::string& parameter_name, const Value& parameter); + + /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name. + * + * This function retrieves the model parameter data from the checkpoint state for the given parameter name. + * The parameter is copied over and returned as an Ort::Value. The training session must be already created + * with the checkpoint state that contains the parameter being retrieved. + * The parameter must exist in the checkpoint state to be able to retrieve it successfully. + * + * \param[in] parameter_name Name of the parameter being retrieved. + * \return The parameter data that is retrieved from the checkpoint state. + * + */ + Value GetParameter(const std::string& parameter_name); + + /// @} +}; + +/** \brief Trainer class that provides training, evaluation and optimizer methods for training ONNX models.
+ * + * The training session requires four training artifacts + * - The training onnx model + * - The evaluation onnx model (optional) + * - The optimizer onnx model + * - The checkpoint file + * + * These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md). + * + */ +class TrainingSession : public detail::Base { + private: + size_t training_model_output_count_, eval_model_output_count_; + + public: + /// \name Constructing the Training Session + /// @{ + /** \brief Create a training session that can be used to begin or resume training. + * + * This constructor instantiates the training session based on the env and session options provided that can + * begin or resume training from a given checkpoint state for the given onnx models. + * The checkpoint state represents the parameters of the training session which will be moved + * to the device specified by the user through the session options (if necessary). + * + * \param[in] env Env to be used for the training session. + * \param[in] session_options SessionOptions that the user can customize for this training session. + * \param[in] checkpoint_state Training states that the training session uses as a starting point for training. + * \param[in] train_model_path Model to be used to perform training. + * \param[in] eval_model_path Model to be used to perform evaluation. + * \param[in] optimizer_model_path Model to be used to perform gradient descent. + * + */ + TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state, + const std::basic_string& train_model_path, + const std::optional>& eval_model_path = std::nullopt, + const std::optional>& optimizer_model_path = std::nullopt); + + /** \brief Create a training session that can be used to begin or resume training. 
+ * This constructor allows the users to load the models from buffers instead of files. + * + * \param[in] env Env to be used for the training session. + * \param[in] session_options SessionOptions that the user can customize for this training session. + * \param[in] checkpoint_state Training states that the training session uses as a starting point for training. + * \param[in] train_model_data Buffer containing training model data. + * \param[in] eval_model_data Buffer containing evaluation model data. + * \param[in] optim_model_data Buffer containing optimizer model (used for performing weight/parameter update). + * + */ + TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state, + const std::vector& train_model_data, const std::vector& eval_model_data = {}, + const std::vector& optim_model_data = {}); + /// @} + + /// \name Implementing The Training Loop + /// @{ + /** \brief Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs + * + * This function performs a training step that computes the outputs of the training model and the gradients + * of the trainable parameters for the given inputs. The train step is performed based on the training model + * that was provided to the training session. + * The Ort::TrainingSession::TrainStep is equivalent of running forward propagation and backward propagation in a single + * step. + * The gradients computed are stored inside the training session state so they can be later consumed + * by the Ort::TrainingSession::OptimizerStep function. + * The gradients can be lazily reset by invoking the Ort::TrainingSession::LazyResetGrad function. + * + * \param[in] input_values The user inputs to the training model. + * \return A std::vector of Ort::Value objects that represents the output of the forward pass of the training model. 
+ * + * + */ + std::vector TrainStep(const std::vector& input_values); + + /** \brief Reset the gradients of all trainable parameters to zero lazily. + * + * This function sets the internal state of the training session such that the gradients of the trainable + * parameters in the OrtCheckpointState will be scheduled to be reset just before the new gradients are + * computed on the next invocation of the next Ort::TrainingSession::TrainStep. + * + */ + void LazyResetGrad(); + + /** \brief Computes the outputs for the eval model for the given inputs + * + * This function performs an eval step that computes the outputs of the eval model for the given inputs. + * The eval step is performed based on the eval model that was provided to the training session. + * + * \param[in] input_values The user inputs to the eval model. + * \return A std::vector of Ort::Value objects that represents the output of the eval pass. + * + */ + std::vector EvalStep(const std::vector& input_values); + + /** \brief Sets the learning rate for this training session. + * + * This function allows users to set the learning rate for the training session. The current + * learning rate is maintained by the training session and can be overwritten by invoking + * this function with the desired learning rate. This function should not be used when a valid + * learning rate scheduler is registered. It should be used either to set the learning rate + * derived from a custom learning rate scheduler or to set a constant learning rate to be used + * throughout the training session. + * \note Please note that this function does not set the initial learning rate that may be needed + * by the predefined learning rate schedulers. To set the initial learning rate for learning + * rate schedulers, please look at the function Ort::TrainingSession::RegisterLinearLRScheduler. + * + * \param[in] learning_rate Desired learning rate to be set. 
+ * + */ + void SetLearningRate(float learning_rate); + + /** \brief Gets the current learning rate for this training session. + * + * This function allows users to get the learning rate for the training session. The current + * learning rate is maintained by the training session, and users can query it for the purpose + * of implementing their own learning rate schedulers. + * + * \return float representing the current learning rate. + * + */ + float GetLearningRate() const; + + /** \brief Registers a linear learning rate scheduler for the training session. + * + * Register a linear learning rate scheduler that decays the learning rate by a linearly updated + * multiplicative factor from the initial learning rate set on the training session to 0. The decay + * is performed after the initial warm up phase where the learning rate is linearly incremented + * from 0 to the initial learning rate provided. + * + * \param[in] warmup_step_count Warmup steps for LR warmup. + * \param[in] total_step_count Total step count. + * \param[in] initial_lr The initial learning rate to be used by the training session. + * + */ + void RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count, + float initial_lr); + + /** \brief Update the learning rate based on the registered learning rate scheduler. + * + * Takes a scheduler step that updates the learning rate that is being used by the training session. + * This function should typically be called before invoking the optimizer step for each round, + * or as determined necessary to update the learning rate being used by the training session. + * \note Please note that a valid predefined learning rate scheduler must be first registered to invoke this + * function. + * + */ + void SchedulerStep(); + + /** \brief Performs the weight updates for the trainable parameters using the optimizer model.
+ * + * This function performs the weight update step that updates the trainable parameters such that they + * take a step in the direction of their gradients (gradient descent). The optimizer step is performed + * based on the optimizer model that was provided to the training session. + * The updated parameters are stored inside the training state so that they can be used by the next + * Ort::TrainingSession::TrainStep function call. + * + */ + void OptimizerStep(); + + /// @} + + /// \name Prepare For Inferencing + /// @{ + + /** \brief Export a model that can be used for inferencing. + * + * If the training session was provided with an eval model, the training session can generate + * an inference model if it knows the inference graph outputs. The input inference graph outputs + * are used to prune the eval model so that the inference model's outputs align with the provided outputs. + * The exported model is saved at the path provided and can be used for inferencing with Ort::Session. + * \note Note that the function re-loads the eval model from the path provided to Ort::TrainingSession + * and expects that this path still be valid. + * + * \param[in] inference_model_path Path where the inference model should be serialized to. + * \param[in] graph_output_names Names of the outputs that are needed in the inference model. + * + */ + void ExportModelForInferencing(const std::basic_string& inference_model_path, + const std::vector& graph_output_names); + + /// @} + + /// \name Model IO Information + /// @{ + /** \brief Retrieves the names of the user inputs for the training and eval models. + * + * This function returns the names of inputs of the training or eval model that can be associated + * with the Ort::Value(s) provided to the Ort::TrainingSession::TrainStep or Ort::TrainingSession::EvalStep + * function. + * + * \param[in] training Whether the training model input names are requested or eval model input names. 
+ * \return Graph input names for either the training model or the eval model. + * + */ + std::vector InputNames(const bool training); + + /** \brief Retrieves the names of the user outputs for the training and eval models. + * + * This function returns the names of outputs of the training or eval model that can be associated + * with the Ort::Value(s) returned by the Ort::TrainingSession::TrainStep or Ort::TrainingSession::EvalStep + * function. + * + * \param[in] training Whether the training model output names are requested or eval model output names. + * \return Graph output names for either the training model or the eval model. + * + */ + std::vector OutputNames(const bool training); + + /// @} + + /// \name Accessing The Training Session State + /// @{ + + /** \brief Returns a contiguous buffer that holds a copy of all training state parameters + * + * \param[in] only_trainable Whether to only copy trainable parameters or to copy all parameters. + * \return Contiguous buffer to the model parameters. + * + */ + Value ToBuffer(const bool only_trainable); + + /** \brief Loads the training session model parameters from a contiguous buffer + * + * In case the training session was created with a nominal checkpoint, invoking this function is required + * to load the updated parameters onto the checkpoint to complete it. + * + * \param[in] buffer Contiguous buffer to load the parameters from. + */ + void FromBuffer(Value& buffer); + + /// @} +}; + +/// \name Training Utilities +/// @{ +/** \brief This function sets the seed for generating random numbers. + * + * Use this function to generate reproducible results. It should be noted that completely + * reproducible results are not guaranteed. + * + * \param[in] seed Manual seed to use for random number generation. 
+ */ +void SetSeed(const int64_t seed); +/// @} + +/// @} + +} // namespace Ort + +#include "onnxruntime_training_cxx_inline.h" diff --git a/libs/onnxruntime/include/onnxruntime_training_cxx_inline.h b/libs/onnxruntime/include/onnxruntime_training_cxx_inline.h new file mode 100644 index 0000000..397cba0 --- /dev/null +++ b/libs/onnxruntime/include/onnxruntime_training_cxx_inline.h @@ -0,0 +1,295 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "onnxruntime_training_c_api.h" +#include "onnxruntime_cxx_api.h" + +namespace Ort { + +inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options, + CheckpointState& checkpoint_state, + const std::basic_string& train_model_path, + const std::optional>& eval_model_path, + const std::optional>& optimizer_model_path) { + ThrowOnError(GetTrainingApi().CreateTrainingSession( + env, session_options, checkpoint_state, + train_model_path.c_str(), + eval_model_path.has_value() ? eval_model_path.value().c_str() : nullptr, + optimizer_model_path.has_value() ? 
optimizer_model_path.value().c_str() : nullptr, + &p_)); + + ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_)); + + ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_)); +} + +inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options, + CheckpointState& checkpoint_state, + const std::vector& train_model_data, + const std::vector& eval_model_data, + const std::vector& optim_model_data) { + ThrowOnError(GetTrainingApi().CreateTrainingSessionFromBuffer( + env, session_options, checkpoint_state, + train_model_data.data(), train_model_data.size(), + eval_model_data.data(), eval_model_data.size(), + optim_model_data.data(), optim_model_data.size(), + &p_)); + + ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_)); + + ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_)); +} + +inline std::vector TrainingSession::TrainStep(const std::vector& input_values) { + std::vector output_values; + output_values.reserve(training_model_output_count_); + for (size_t i = 0; i < training_model_output_count_; i++) output_values.emplace_back(nullptr); + auto ort_input_values = reinterpret_cast(input_values.data()); + auto ort_output_values = reinterpret_cast(output_values.data()); + RunOptions run_options; + ThrowOnError(GetTrainingApi().TrainStep( + p_, run_options, input_values.size(), ort_input_values, + training_model_output_count_, ort_output_values)); + + return output_values; +} + +inline void TrainingSession::LazyResetGrad() { + ThrowOnError(GetTrainingApi().LazyResetGrad(p_)); +} + +inline std::vector TrainingSession::EvalStep(const std::vector& input_values) { + std::vector output_values; + output_values.reserve(eval_model_output_count_); + for (size_t i = 0; i < eval_model_output_count_; i++) output_values.emplace_back(nullptr); + auto 
ort_input_values = reinterpret_cast(input_values.data()); + auto ort_output_values = reinterpret_cast(output_values.data()); + RunOptions run_options; + ThrowOnError(GetTrainingApi().EvalStep( + p_, run_options, input_values.size(), ort_input_values, + eval_model_output_count_, ort_output_values)); + + return output_values; +} + +inline void TrainingSession::SetLearningRate(float learning_rate) { + ThrowOnError(GetTrainingApi().SetLearningRate(p_, learning_rate)); +} + +inline float TrainingSession::GetLearningRate() const { + float learning_rate = 0; + ThrowOnError(GetTrainingApi().GetLearningRate(p_, &learning_rate)); + return learning_rate; +} + +inline void TrainingSession::RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count, + float initial_lr) { + ThrowOnError(GetTrainingApi().RegisterLinearLRScheduler(p_, warmup_step_count, total_step_count, + initial_lr)); +} + +inline void TrainingSession::SchedulerStep() { + ThrowOnError(GetTrainingApi().SchedulerStep(p_)); +} + +inline void TrainingSession::OptimizerStep() { + RunOptions run_options; + ThrowOnError(GetTrainingApi().OptimizerStep(p_, run_options)); +} + +inline std::vector TrainingSession::InputNames(const bool training) { + auto& input_count_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelInputCount + : GetTrainingApi().TrainingSessionGetEvalModelInputCount; + auto& input_name_function = training ? 
GetTrainingApi().TrainingSessionGetTrainingModelInputName + : GetTrainingApi().TrainingSessionGetEvalModelInputName; + + size_t input_count = 0; + ThrowOnError(input_count_function(p_, &input_count)); + std::vector input_names(input_count); + AllocatorWithDefaultOptions allocator; + for (size_t index = 0; index < input_count; ++index) { + char* input_name; + ThrowOnError(input_name_function(p_, index, allocator, &input_name)); + input_names[index] = std::string(input_name); + allocator.Free(input_name); + } + + return input_names; +} + +inline std::vector TrainingSession::OutputNames(const bool training) { + auto& output_count_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelOutputCount + : GetTrainingApi().TrainingSessionGetEvalModelOutputCount; + auto& output_name_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelOutputName + : GetTrainingApi().TrainingSessionGetEvalModelOutputName; + + size_t output_count = 0; + ThrowOnError(output_count_function(p_, &output_count)); + std::vector output_names(output_count); + AllocatorWithDefaultOptions allocator; + for (size_t index = 0; index < output_count; ++index) { + char* output_name; + ThrowOnError(output_name_function(p_, index, allocator, &output_name)); + output_names[index] = std::string(output_name); + allocator.Free(output_name); + } + + return output_names; +} + +inline Value TrainingSession::ToBuffer(const bool only_trainable) { + size_t buffer_size = 0U; + ThrowOnError(GetTrainingApi().GetParametersSize(p_, &buffer_size, only_trainable)); + + std::array buffer_shape{static_cast(buffer_size)}; + + AllocatorWithDefaultOptions allocator; + Value buffer = Value::CreateTensor(allocator, buffer_shape.data(), 1U, + ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + + ThrowOnError(GetTrainingApi().CopyParametersToBuffer(p_, buffer, only_trainable)); + + return buffer; +} + +inline void TrainingSession::FromBuffer(Value& buffer) { + if (!buffer.IsTensor()) { + 
ThrowStatus(Status("Incorrect buffer received. Expected a tensor buffer.", OrtErrorCode::ORT_INVALID_ARGUMENT)); + } + + auto tensor_info = buffer.GetTensorTypeAndShapeInfo(); + auto buffer_shape = tensor_info.GetShape(); + + if (buffer_shape.size() != 1U) { + ThrowStatus(Status("Incorrect buffer received. Expected a contiguous tensor buffer.", + OrtErrorCode::ORT_INVALID_ARGUMENT)); + } + + auto buffer_size = buffer_shape.front(); + + size_t session_buffer_size = 0U; + ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size, false)); + + if (buffer_size == static_cast(session_buffer_size)) { + ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, false)); + return; + } + + size_t session_buffer_size_trainable_only = 0U; + ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size_trainable_only, true)); + + if (buffer_size == static_cast(session_buffer_size_trainable_only)) { + ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, true)); + return; + } else { + ThrowStatus(Status("Incorrect buffer size received.", OrtErrorCode::ORT_INVALID_ARGUMENT)); + } +} + +inline CheckpointState CheckpointState::LoadCheckpoint(const std::basic_string& path_to_checkpoint) { + OrtCheckpointState* checkpoint_state; + ThrowOnError(GetTrainingApi().LoadCheckpoint(path_to_checkpoint.c_str(), &checkpoint_state)); + return CheckpointState(checkpoint_state); +} + +inline CheckpointState CheckpointState::LoadCheckpointFromBuffer(const std::vector& buffer) { + OrtCheckpointState* checkpoint_state; + ThrowOnError(GetTrainingApi().LoadCheckpointFromBuffer(buffer.data(), buffer.size(), &checkpoint_state)); + return CheckpointState(checkpoint_state); +} + +inline void CheckpointState::SaveCheckpoint(const CheckpointState& checkpoint_states, + const std::basic_string& path_to_checkpoint, + const bool include_optimizer_state) { + ThrowOnError(GetTrainingApi().SaveCheckpoint(checkpoint_states, path_to_checkpoint.c_str(), + 
include_optimizer_state)); +} + +inline void TrainingSession::ExportModelForInferencing(const std::basic_string& inference_model_path, + const std::vector& graph_output_names) { + std::vector output_names; + output_names.reserve(graph_output_names.size()); + for (const auto& output_name : graph_output_names) { + output_names.push_back(output_name.c_str()); + } + ThrowOnError(GetTrainingApi().ExportModelForInferencing( + p_, inference_model_path.c_str(), graph_output_names.size(), output_names.data())); +} + +inline void SetSeed(const int64_t seed) { + ThrowOnError(GetTrainingApi().SetSeed(seed)); +} + +inline void CheckpointState::AddProperty(const std::string& property_name, const Property& property_value) { + if (std::holds_alternative(property_value)) { + int64_t value = std::get(property_value); + void* value_p = &value; + ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtIntProperty, value_p)); + } else if (std::holds_alternative(property_value)) { + float value = std::get(property_value); + void* value_p = &value; + ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtFloatProperty, value_p)); + } else if (std::holds_alternative(property_value)) { + std::string value = std::get(property_value); + auto buffer = std::make_unique(value.length() + 1); + memcpy(buffer.get(), value.c_str(), value.length()); + // AddProperty takes a char* and calls PropertyBag::AddProperty which takes a std::string. The data will be + // copied at that point so buffer can free the local allocation once the call is made. 
+ ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtStringProperty, + buffer.get())); + } else { + ThrowStatus(Status("Unknown property type received.", OrtErrorCode::ORT_INVALID_ARGUMENT)); + } +} + +inline Property CheckpointState::GetProperty(const std::string& property_name) { + void* property_value = nullptr; + OrtPropertyType property_type; + + AllocatorWithDefaultOptions allocator; + ThrowOnError(GetTrainingApi().GetProperty(p_, property_name.c_str(), allocator, &property_type, &property_value)); + + Property property; + + switch (property_type) { + case OrtPropertyType::OrtIntProperty: { + auto value_p = reinterpret_cast<int64_t*>(property_value); + property = *value_p; + allocator.Free(property_value); + break; + } + case OrtPropertyType::OrtFloatProperty: { + auto value_p = reinterpret_cast<float*>(property_value); + property = *value_p; + allocator.Free(property_value); + break; + } + case OrtPropertyType::OrtStringProperty: { + auto value_p = reinterpret_cast<char*>(property_value); + property = std::string(value_p); + allocator.Free(property_value); + break; + } + default: { + ThrowStatus(Status("Unknown property type received.", OrtErrorCode::ORT_INVALID_ARGUMENT)); + break; + } + } + + return property; +} + +inline void CheckpointState::UpdateParameter(const std::string& parameter_name, const Value& parameter) { + ThrowOnError(GetTrainingApi().UpdateParameter(p_, parameter_name.c_str(), parameter)); +} + +inline Value CheckpointState::GetParameter(const std::string& parameter_name) { + AllocatorWithDefaultOptions allocator; + OrtValue* parameter; + ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, &parameter)); + + return Value{parameter}; +} + +} // namespace Ort diff --git a/libs/onnxruntime/include/tensorrt_provider_factory.h b/libs/onnxruntime/include/tensorrt_provider_factory.h deleted file mode 100644 index ffbd170..0000000 --- a/libs/onnxruntime/include/tensorrt_provider_factory.h +++ /dev/null @@
-1,14 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "onnxruntime_c_api.h" - -#ifdef __cplusplus -extern "C" { -#endif - -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id); - -#ifdef __cplusplus -} -#endif diff --git a/libs/onnxruntime/lib/osx/libonnxruntime.1.10.0.dylib b/libs/onnxruntime/lib/osx/libonnxruntime.1.10.0.dylib deleted file mode 100755 index c4b6ac8..0000000 Binary files a/libs/onnxruntime/lib/osx/libonnxruntime.1.10.0.dylib and /dev/null differ diff --git a/src/ofxOnnxRuntime.cpp b/src/ofxOnnxRuntime.cpp index 8eae86e..cd82f69 100644 --- a/src/ofxOnnxRuntime.cpp +++ b/src/ofxOnnxRuntime.cpp @@ -19,21 +19,17 @@ namespace ofxOnnxRuntime void BaseHandler::setup(const std::string & onnx_path, const BaseSetting & base_setting) { Ort::SessionOptions session_options; - if (base_setting.infer_type == INFER_TENSORRT) { - OrtTensorRTProviderOptions op; - memset(&op, 0, sizeof(op)); - op.device_id = base_setting.device_id; - op.trt_fp16_enable = 1; - op.trt_engine_cache_enable = 1; - std::string path = ofToDataPath(onnx_path, true); - ofStringReplace(path, ".onnx", "_trt_cache"); - op.trt_engine_cache_path = path.c_str(); - session_options.AppendExecutionProvider_TensorRT(op); - } - if (base_setting.infer_type == INFER_CUDA || base_setting.infer_type == INFER_TENSORRT) { - OrtCUDAProviderOptions op; - op.device_id = base_setting.device_id; - session_options.AppendExecutionProvider_CUDA(op); + session_options.SetIntraOpNumThreads(1); + session_options.SetIntraOpNumThreads(1); + session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); + + if (base_setting.infer_type == INFER_CUDA) { + OrtCUDAProviderOptions opts; + opts.device_id = 0; + opts.cudnn_conv_algo_search = OrtCudnnConvAlgoSearchExhaustive; + opts.do_copy_in_default_stream = 0; + opts.arena_extend_strategy = 0; + 
session_options.AppendExecutionProvider_CUDA(opts); } this->setup2(onnx_path, session_options); } @@ -49,49 +45,86 @@ namespace ofxOnnxRuntime Ort::AllocatorWithDefaultOptions allocator; - // 2. input name & input dims - auto* input_name = ort_session->GetInputName(0, allocator); - input_node_names.resize(1); - input_node_names[0] = input_name; - - // 3. type info. - Ort::TypeInfo type_info = ort_session->GetInputTypeInfo(0); - auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); - input_tensor_size = 1; - input_node_dims = tensor_info.GetShape(); - for (unsigned int i = 0; i < input_node_dims.size(); ++i) - input_tensor_size *= input_node_dims.at(i); - input_values_handler.resize(input_tensor_size); - - // 4. output names & output dimms - num_outputs = ort_session->GetOutputCount(); - output_node_names.resize(num_outputs); + // 1. Gets Input Name/s & Shape ([1, 3, 28, 28]) -- In most cases this is usually just one + for (std::size_t i = 0; i < ort_session->GetInputCount(); i++) { + input_node_names.emplace_back(ort_session->GetInputNameAllocated(i, allocator).get()); + input_node_dims = ort_session->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape(); + + // Some models might have negative shape values to indicate dynamic shape, e.g., for variable batch size. (?, 3, 28, 28) -> (1, 3, 28, 28) + for (auto& s : input_node_dims) if (s < 0) s = 1; + + std::cout << input_node_names.at(i) << " : " << PrintShape(input_node_dims) << std::endl; + } + + // 2. Clear up output values output_node_dims.clear(); output_values.clear(); - for (unsigned int i = 0; i < num_outputs; ++i) - { - output_node_names[i] = ort_session->GetOutputName(i, allocator); - Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo(i); - auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo(); - auto output_dims = output_tensor_info.GetShape(); - output_node_dims.emplace_back(output_dims); + + // 3. 
Gets Output name/s & Shapes + for (std::size_t i = 0; i < ort_session->GetOutputCount(); i++) { + output_node_names.emplace_back(ort_session->GetOutputNameAllocated(i, allocator).get()); + auto output_shapes = ort_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape(); + + for (auto& s : output_shapes) if (s < 0) s = 1; + output_values.emplace_back(nullptr); + + std::cout << output_node_names.at(i) << " : " << PrintShape(output_shapes) << std::endl; } } Ort::Value& BaseHandler::run() { - auto input_tensor_ = Ort::Value::CreateTensor( - memory_info_handler, input_values_handler.data(), input_tensor_size, - input_node_dims.data(), input_node_dims.size()); - ort_session->Run(Ort::RunOptions{ nullptr }, input_node_names.data(), &input_tensor_, input_node_names.size(), - output_node_names.data(), output_values.data(), output_node_names.size()); + std::vector input_tensors; + + input_tensors.emplace_back(GenerateTensor()); + + // transform std::string -> const char* + std::vector input_names_char(input_node_names.size(), nullptr); + std::transform(std::begin(input_node_names), std::end(input_node_names), std::begin(input_names_char), + [&](const std::string& str) { return str.c_str(); }); - if (output_values.size() == 1) { + std::vector output_names_char(output_node_names.size(), nullptr); + std::transform(std::begin(output_node_names), std::end(output_node_names), std::begin(output_names_char), + [&](const std::string& str) { return str.c_str(); }); + + + try { + output_values = ort_session->Run(Ort::RunOptions{ nullptr }, input_names_char.data(), input_tensors.data(), + input_names_char.size(), output_names_char.data(), output_names_char.size()); + std::cout << "Success!" << std::endl; return output_values.at(0); } - else { - return dummy_tensor; + catch (const Ort::Exception& ex) { + std::cout << "ERROR running model inference: " << ex.what() << std::endl; + return dummy_output_tensor.at(0); } + + } + + // Prints the shape of the given tensor (ex. 
input: (1, 1, 512, 512)) + std::string BaseHandler::PrintShape(const std::vector& v) { + std::stringstream ss; + for (std::size_t i = 0; i < v.size() - 1; i++) ss << v[i] << "x"; + ss << v[v.size() - 1]; + return ss.str(); + } + + Ort::Value BaseHandler::GenerateTensor() { + std::vector random_input_tensor_values(CalculateProduct(input_node_dims)); + std::generate(random_input_tensor_values.begin(), random_input_tensor_values.end(), [&] { return rand() % 255; }); + return VectorToTensor(random_input_tensor_values, input_node_dims); + } + + int BaseHandler::CalculateProduct(const std::vector& v) { + int total = 1; + for (auto& i : v) total *= i; + return total; + } + + Ort::Value BaseHandler::VectorToTensor(std::vector& data, const std::vector& shape) { + Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); + auto tensor = Ort::Value::CreateTensor(mem_info, data.data(), data.size(), shape.data(), shape.size()); + return tensor; } } diff --git a/src/ofxOnnxRuntime.h b/src/ofxOnnxRuntime.h index fefc7ae..ab56941 100644 --- a/src/ofxOnnxRuntime.h +++ b/src/ofxOnnxRuntime.h @@ -27,21 +27,29 @@ namespace ofxOnnxRuntime Ort::Value& run(); - float* getInputTensorData() { - return this->input_values_handler.data(); - } + // Utilities + std::string PrintShape(const std::vector& v); + Ort::Value GenerateTensor(); + int CalculateProduct(const std::vector& v); + Ort::Value VectorToTensor(std::vector& data, const std::vector& shape); + protected: Ort::Env ort_env; std::shared_ptr ort_session; - std::vector input_node_names; + + std::vector input_node_names; std::vector input_node_dims; // 1 input only. 
std::size_t input_tensor_size = 1; - std::vector input_values_handler; + Ort::MemoryInfo memory_info_handler = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); - std::vector output_node_names; + + std::vector output_node_names; std::vector> output_node_dims; // >=1 outputs std::vector output_values; + Ort::Value dummy_tensor{ nullptr }; + std::vector dummy_output_tensor; + int num_outputs = 1; }; }