// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. // Summary: The Ort C++ API is a header only wrapper around the Ort C API. // // The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors // and automatically releasing resources in the destructors. // // Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers. // To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). // // Only move assignment between objects is allowed, there are no copy constructors. Some objects have explicit 'Clone' // methods for this purpose. #pragma once #include "onnxruntime_c_api.h" #include #include #include #include #include #include #include #include #ifdef ORT_NO_EXCEPTIONS #include #endif /** \brief All C++ Onnxruntime APIs are defined inside this namespace * */ namespace Ort { /** \brief All C++ methods that can fail will throw an exception of this type * * If ORT_NO_EXCEPTIONS is defined, then any error will result in a call to abort() */ struct Exception : std::exception { Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {} OrtErrorCode GetOrtErrorCode() const { return code_; } const char* what() const noexcept override { return message_.c_str(); } private: std::string message_; OrtErrorCode code_; }; #ifdef ORT_NO_EXCEPTIONS #define ORT_CXX_API_THROW(string, code) \ do { \ std::cerr << Ort::Exception(string, code) \ .what() \ << std::endl; \ abort(); \ } while (false) #else #define ORT_CXX_API_THROW(string, code) \ throw Ort::Exception(string, code) #endif // This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, it's in a template so that we can define a global variable in a header and make // it transparent to the users of the API. 
template struct Global { static const OrtApi* api_; }; // If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it. template #ifdef ORT_API_MANUAL_INIT const OrtApi* Global::api_{}; inline void InitApi() { Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); } #else const OrtApi* Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); #endif /// This returns a reference to the OrtApi interface in use inline const OrtApi& GetApi() { return *Global::api_; } /// This is a C++ wrapper for OrtApi::GetAvailableProviders() and returns a vector of strings representing the available execution providers. std::vector GetAvailableProviders(); // This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type // This can't be done in the C API since C doesn't have function overloading. #define ORT_DEFINE_RELEASE(NAME) \ inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); } ORT_DEFINE_RELEASE(Allocator); ORT_DEFINE_RELEASE(MemoryInfo); ORT_DEFINE_RELEASE(CustomOpDomain); ORT_DEFINE_RELEASE(Env); ORT_DEFINE_RELEASE(RunOptions); ORT_DEFINE_RELEASE(Session); ORT_DEFINE_RELEASE(SessionOptions); ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo); ORT_DEFINE_RELEASE(SequenceTypeInfo); ORT_DEFINE_RELEASE(MapTypeInfo); ORT_DEFINE_RELEASE(TypeInfo); ORT_DEFINE_RELEASE(Value); ORT_DEFINE_RELEASE(ModelMetadata); ORT_DEFINE_RELEASE(ThreadingOptions); ORT_DEFINE_RELEASE(IoBinding); ORT_DEFINE_RELEASE(ArenaCfg); #undef ORT_DEFINE_RELEASE /** \brief IEEE 754 half-precision floating point data type * \details It is necessary for type dispatching to make use of C++ API * The type is implicitly convertible to/from uint16_t. * The size of the structure should align with uint16_t and one can freely cast * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data. 
*
 * Generally, you can feed any of your types as float16/bfloat16 data to create a tensor
 * on top of it, provided it can form a contiguous buffer of 16-bit elements with no padding.
 * You can also feed an array of uint16_t elements directly. For example,
 *
 * \code{.unparsed}
 * uint16_t values[] = { 15360, 16384, 16896, 17408, 17664};
 * constexpr size_t values_length = sizeof(values) / sizeof(values[0]);
 * std::vector<int64_t> dims = {values_length};  // one dimensional example
 * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
 * // Note we are passing bytes count in this api, not number of elements -> sizeof(values)
 * auto float16_tensor = Ort::Value::CreateTensor(info, values, sizeof(values),
 *                                                dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
 * \endcode
 *
 * Here is another example, a little bit more elaborate. Let's assume that you use your own float16 type and you want to use
 * a templated version of the API above so the type is automatically set based on your type. You will need to supply an extra
 * template specialization.
* * \code{.unparsed} * namespace yours { struct half {}; } // assume this is your type, define this: * namespace Ort { * template<> * struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; }; * } //namespace Ort * * std::vector values; * std::vector dims = {values.size()}; // one dimensional example * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); * // Here we are passing element count -> values.size() * auto float16_tensor = Ort::Value::CreateTensor(info, values.data(), values.size(), dims.data(), dims.size()); * * \endcode */ struct Float16_t { uint16_t value; constexpr Float16_t() noexcept : value(0) {} constexpr Float16_t(uint16_t v) noexcept : value(v) {} constexpr operator uint16_t() const noexcept { return value; } constexpr bool operator==(const Float16_t& rhs) const noexcept { return value == rhs.value; }; constexpr bool operator!=(const Float16_t& rhs) const noexcept { return value != rhs.value; }; }; static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match"); /** \brief bfloat16 (Brain Floating Point) data type * \details It is necessary for type dispatching to make use of C++ API * The type is implicitly convertible to/from uint16_t. * The size of the structure should align with uint16_t and one can freely cast * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data. * * See also code examples for Float16_t above. */ struct BFloat16_t { uint16_t value; constexpr BFloat16_t() noexcept : value(0) {} constexpr BFloat16_t(uint16_t v) noexcept : value(v) {} constexpr operator uint16_t() const noexcept { return value; } constexpr bool operator==(const BFloat16_t& rhs) const noexcept { return value == rhs.value; }; constexpr bool operator!=(const BFloat16_t& rhs) const noexcept { return value != rhs.value; }; }; static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match"); /** \brief Used internally by the C++ API. 
C++ wrapper types inherit from this * * This is a zero cost abstraction to wrap the C API objects and delete them on destruction. * There is a secondary class 'Unowned' that is used to prevent deletion on destruction (Used for return types that are * not owned by the caller) * */ template struct Base { using contained_type = T; Base() = default; Base(T* p) : p_{p} { if (!p) ORT_CXX_API_THROW("Allocation failure", ORT_FAIL); } ~Base() { OrtRelease(p_); } operator T*() { return p_; } operator const T*() const { return p_; } /// \brief Releases ownership of the contained pointer T* release() { T* p = p_; p_ = nullptr; return p; } protected: Base(const Base&) = delete; Base& operator=(const Base&) = delete; Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } void operator=(Base&& v) noexcept { OrtRelease(p_); p_ = v.release(); } T* p_{}; template friend struct Unowned; // This friend line is needed to keep the centos C++ compiler from giving an error }; /** \brief Wraps an object that inherits from Ort::Base and stops it from deleting the contained pointer on destruction * * This has the effect of making it not own the memory held by Ort::Base. */ template struct Unowned : T { Unowned(typename T::contained_type* p) : T{p} {} Unowned(Unowned&& v) : T{v.p_} {} ~Unowned() { this->release(); } }; struct AllocatorWithDefaultOptions; struct MemoryInfo; struct Env; struct TypeInfo; struct Value; struct ModelMetadata; /** \brief The Env (Environment) * * The Env holds the logging state used by all other objects. 
* Note: One Env must be created before using any other Onnxruntime functionality */ struct Env : Base { explicit Env(std::nullptr_t) {} ///< Create an empty Env object, must be assigned a valid one to be used /// \brief Wraps OrtApi::CreateEnv Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); /// \brief Wraps OrtApi::CreateEnvWithCustomLogger Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param); /// \brief Wraps OrtApi::CreateEnvWithGlobalThreadPools Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); /// \brief Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); /// \brief C Interop Helper explicit Env(OrtEnv* p) : Base{p} {} Env& EnableTelemetryEvents(); ///< Wraps OrtApi::EnableTelemetryEvents Env& DisableTelemetryEvents(); ///< Wraps OrtApi::DisableTelemetryEvents Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator }; /** \brief Custom Op Domain * */ struct CustomOpDomain : Base { explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used /// \brief Wraps OrtApi::CreateCustomOpDomain explicit CustomOpDomain(const char* domain); void Add(OrtCustomOp* op); ///< Wraps CustomOpDomain_Add }; struct RunOptions : Base { explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used RunOptions(); ///< Wraps OrtApi::CreateRunOptions RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel int GetRunLogVerbosityLevel() const; ///< Wraps 
OrtApi::RunOptionsGetRunLogVerbosityLevel RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry /** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance * * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error * Wraps OrtApi::RunOptionsSetTerminate */ RunOptions& SetTerminate(); /** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating * * Wraps OrtApi::RunOptionsUnsetTerminate */ RunOptions& UnsetTerminate(); }; /** \brief Options object used when creating a new Session object * * Wraps ::OrtSessionOptions object and methods */ struct SessionOptions : Base { explicit SessionOptions(std::nullptr_t) {} ///< Create an empty SessionOptions object, must be assigned a valid one to be used SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions explicit SessionOptions(OrtSessionOptions* p) : Base{p} {} ///< Used for interop with the C API SessionOptions Clone() const; ///< Creates and returns a copy of this SessionOptions object. 
Wraps OrtApi::CloneSessionOptions SessionOptions& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads SessionOptions& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads SessionOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel SessionOptions& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena SessionOptions& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena SessionOptions& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath SessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps OrtApi::EnableProfiling SessionOptions& DisableProfiling(); ///< Wraps OrtApi::DisableProfiling SessionOptions& EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps SessionOptions& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern SessionOptions& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern SessionOptions& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode SessionOptions& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId SessionOptions& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel SessionOptions& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain SessionOptions& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads SessionOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer SessionOptions& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& 
provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT SessionOptions& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn SessionOptions& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions SessionOptions& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn }; /** \brief Wrapper around ::OrtModelMetadata * */ struct ModelMetadata : Base { explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used explicit ModelMetadata(OrtModelMetadata* p) : Base{p} {} ///< Used for interop with the C API char* GetProducerName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName char* GetGraphName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName char* GetDomain(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain char* GetDescription(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription char* GetGraphDescription(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap int64_t GetVersion() const; ///< Wraps 
OrtApi::ModelMetadataGetVersion }; /** \brief Wrapper around ::OrtSession * */ struct Session : Base { explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray /** \brief Run the model returning results in an Ort allocated vector. * * Wraps OrtApi::Run * * The caller provides a list of inputs and a list of the desired outputs to return. * * See the output logs for more information on warnings/errors that occur while processing the model. * Common errors are.. (TODO) * * \param[in] run_options * \param[in] input_names Array of null terminated strings of length input_count that is the list of input names * \param[in] input_values Array of Value objects of length input_count that is the list of input values * \param[in] input_count Number of inputs (the size of the input_names & input_values arrays) * \param[in] output_names Array of C style strings of length output_count that is the list of output names * \param[in] output_count Number of outputs (the size of the output_names array) * \return A std::vector of Value objects that directly maps to the output_count (eg. 
output_name[0] is the first entry of the returned vector) */ std::vector Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, const char* const* output_names, size_t output_count); /** \brief Run the model returning results in user provided outputs * Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t) */ void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, const char* const* output_names, Value* output_values, size_t output_count); void Run(const RunOptions& run_options, const struct IoBinding&); ///< Wraps OrtApi::RunWithBinding size_t GetInputCount() const; ///< Returns the number of model inputs size_t GetOutputCount() const; ///< Returns the number of model outputs size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden char* GetInputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetInputName char* GetOutputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOutputName char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName char* EndProfiling(OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionEndProfiling uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo }; /** \brief Wrapper around ::OrtTensorTypeAndShapeInfo * */ struct TensorTypeAndShapeInfo : Base { 
explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : Base{p} {} ///< Used for interop with the C API ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions std::vector GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape }; /** \brief Wrapper around ::OrtSequenceTypeInfo * */ struct SequenceTypeInfo : Base { explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : Base{p} {} ///< Used for interop with the C API TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType }; /** \brief Wrapper around ::OrtMapTypeInfo * */ struct MapTypeInfo : Base { explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used explicit MapTypeInfo(OrtMapTypeInfo* p) : Base{p} {} ///< Used for interop with the C API ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType }; struct TypeInfo : Base { explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used explicit TypeInfo(OrtTypeInfo* p) : Base{p} {} ///< C API Interop Unowned GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo Unowned GetSequenceTypeInfo() const; 
///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo Unowned GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo ONNXType GetONNXType() const; }; struct Value : Base { // This structure is used to feed sparse tensor values // information for use with FillSparseTensor() API // if the data type for the sparse tensor values is numeric // use data.p_data, otherwise, use data.str pointer to feed // values. data.str is an array of const char* that are zero terminated. // number of strings in the array must match shape size. // For fully sparse tensors use shape {0} and set p_data/str // to nullptr. struct OrtSparseValuesParam { const int64_t* values_shape; size_t values_shape_len; union { const void* p_data; const char** str; } data; }; // Provides a way to pass shape in a single // argument struct Shape { const int64_t* shape; size_t shape_len; }; /// \brief Wraps OrtApi::CreateTensorWithDataAsOrtValue template static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len); /// \brief Wraps OrtApi::CreateTensorWithDataAsOrtValue static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); #if !defined(DISABLE_SPARSE_TENSORS) /// /// This is a simple forwarding method to the other overload that helps deducing /// data type enum value from the type of the buffer. /// /// numeric datatype. This API is not suitable for strings. /// Memory description where the user buffers reside (CPU vs GPU etc) /// pointer to the user supplied buffer, use nullptr for fully sparse tensors /// a would be dense shape of the tensor /// non zero values shape. Use a single 0 shape for fully sparse tensors. /// template static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape, const Shape& values_shape); /// /// Creates an OrtValue instance containing SparseTensor. 
This constructs /// a sparse tensor that makes use of user allocated buffers. It does not make copies /// of the user provided data and does not modify it. The lifespan of user provided buffers should /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contain /// a pointer to non-zero values. To fully populate the sparse tensor call UseIndices() API below /// to supply a sparse format specific indices. /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings /// can be properly copied into the allocated buffer. /// /// Memory description where the user buffers reside (CPU vs GPU etc) /// pointer to the user supplied buffer, use nullptr for fully sparse tensors /// a would be dense shape of the tensor /// non zero values shape. Use a single 0 shape for fully sparse tensors. /// data type /// Ort::Value instance containing SparseTensor static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape, const Shape& values_shape, ONNXTensorElementDataType type); /// /// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor. /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user /// allocated buffers lifespan must eclipse that of the OrtValue. /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time. /// /// pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors. /// number of indices entries. Use 0 for fully sparse tensors void UseCooIndices(int64_t* indices_data, size_t indices_num); /// /// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor. /// Values are supplied with a CreateSparseTensor() API. 
The supplied indices are not copied and the user /// allocated buffers lifespan must eclipse that of the OrtValue. /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time. /// /// pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors /// number of csr inner indices or 0 for fully sparse tensors /// pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors /// number of csr outer indices or 0 for fully sparse tensors void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num); /// /// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor. /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user /// allocated buffers lifespan must eclipse that of the OrtValue. /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time. /// /// indices shape or a {0} for fully sparse /// user allocated buffer with indices or nullptr for fully spare tensors void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data); #endif // !defined(DISABLE_SPARSE_TENSORS) // \brief Wraps OrtApi::CreateTensorAsOrtValue template static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len); // \brief Wraps OrtApi::CreateTensorAsOrtValue static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); #if !defined(DISABLE_SPARSE_TENSORS) /// /// This is a simple forwarding method the below CreateSparseTensor. /// This helps to specify data type enum in terms of C++ data type. /// Use CreateSparseTensor /// /// numeric data type only. String data enum must be specified explicitly. 
/// allocator to use /// a would be dense shape of the tensor /// Ort::Value template static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape); /// /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data. /// The data must be supplied by on of the FillSparseTensor() methods that take both non-zero values /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator. /// Use this API to create OrtValues that contain sparse tensors with all supported data types including /// strings. /// /// allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue /// a would be dense shape of the tensor /// data type /// an instance of Ort::Value static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type); /// /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API /// and copy the values and COO indices into it. If data_mem_info specifies that the data is located /// at difference device than the allocator, a X-device copy will be performed if possible. /// /// specified buffer memory description /// values buffer information. /// coo indices buffer or nullptr for fully sparse data /// number of COO indices or 0 for fully sparse data void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param, const int64_t* indices_data, size_t indices_num); /// /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API /// and copy the values and CSR indices into it. If data_mem_info specifies that the data is located /// at difference device than the allocator, a X-device copy will be performed if possible. 
/// /// specified buffer memory description /// values buffer information /// csr inner indices pointer or nullptr for fully sparse tensors /// number of csr inner indices or 0 for fully sparse tensors /// pointer to csr indices data or nullptr for fully sparse tensors /// number of csr outer indices or 0 void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values, const int64_t* inner_indices_data, size_t inner_indices_num, const int64_t* outer_indices_data, size_t outer_indices_num); /// /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API /// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located /// at difference device than the allocator, a X-device copy will be performed if possible. /// /// specified buffer memory description /// values buffer information /// indices shape. use {0} for fully sparse tensors /// pointer to indices data or nullptr for fully sparse tensors void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values, const Shape& indices_shape, const int32_t* indices_data); /// /// The API returns the sparse data format this OrtValue holds in a sparse tensor. /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used /// the value returned is ORT_SPARSE_UNDEFINED. /// /// Format enum OrtSparseFormat GetSparseFormat() const; /// /// The API returns type and shape information for stored non-zero values of the /// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer. /// /// TensorTypeAndShapeInfo values information TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const; /// /// The API returns type and shape information for the specified indices. Each supported /// indices have their own enum values even if a give format has more than one kind of indices. 
/// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer. /// /// enum requested /// type and shape information TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const; /// /// The API retrieves a pointer to the internal indices buffer. The API merely performs /// a convenience data type casting on the return type pointer. Make sure you are requesting /// the right type, use GetSparseTensorIndicesTypeShapeInfo(); /// /// type to cast to /// requested indices kind /// number of indices entries /// Pinter to the internal sparse tensor buffer containing indices. Do not free this pointer. template const T* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const; #endif // !defined(DISABLE_SPARSE_TENSORS) static Value CreateMap(Value& keys, Value& values); ///< Wraps OrtApi::CreateValue static Value CreateSequence(std::vector& values); ///< Wraps OrtApi::CreateValue template static Value CreateOpaque(const char* domain, const char* type_name, const T&); ///< Wraps OrtApi::CreateOpaqueValue template void GetOpaqueData(const char* domain, const char* type_name, T&) const; ///< Wraps OrtApi::GetOpaqueValue explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API Value(Value&&) = default; Value& operator=(Value&&) = default; bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc bool HasValue() const; /// < Return true if OrtValue contains data and returns false if the OrtValue is a None #if !defined(DISABLE_SPARSE_TENSORS) /// /// Returns true if the OrtValue contains a sparse tensor /// /// bool IsSparseTensor() const; #endif size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements Value GetValue(int index, OrtAllocator* allocator) const; /// 
/// This API returns a full length of string data contained within either a tensor or a sparse Tensor. /// For sparse tensor it returns a full length of stored non-empty strings (values). The API is useful /// for allocating necessary memory and calling GetStringTensorContent(). /// /// total length of UTF-8 encoded bytes contained. No zero terminators counted. size_t GetStringTensorDataLength() const; /// /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate. /// The user must also allocate offsets buffer with the number of entries equal to that of the contained /// strings. /// /// Strings are always assumed to be on CPU, no X-device copy. /// /// user allocated buffer /// length in bytes of the allocated buffer /// a pointer to the offsets user allocated buffer /// count of offsets, must be equal to the number of strings contained. /// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo() /// for sparse tensors void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const; template T* GetTensorMutableData(); ///< Wraps OrtApi::GetTensorMutableData template const T* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData #if !defined(DISABLE_SPARSE_TENSORS) /// /// The API returns a pointer to an internal buffer of the sparse tensor /// containing non-zero values. The API merely does casting. Make sure you /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo() /// first. /// /// numeric data types only. Use GetStringTensor*() to retrieve strings. /// a pointer to the internal values buffer. Do not free this pointer. 
template const T* GetSparseTensorValues() const; #endif template T& At(const std::vector& location); /// /// The API returns type information for data contained in a tensor. For sparse /// tensors it returns type information for contained non-zero values. /// It returns dense shape for sparse tensors. /// /// TypeInfo TypeInfo GetTypeInfo() const; /// /// The API returns type information for data contained in a tensor. For sparse /// tensors it returns type information for contained non-zero values. /// It returns dense shape for sparse tensors. /// /// TensorTypeAndShapeInfo TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; /// /// The API returns a byte length of UTF-8 encoded string element /// contained in either a tensor or a spare tensor values. /// /// /// byte length for the specified string element size_t GetStringTensorElementLength(size_t element_index) const; /// /// The API copies UTF-8 encoded bytes for the requested string element /// contained within a tensor or a sparse tensor into a provided buffer. /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate. 
/// /// /// /// void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const; void FillStringTensor(const char* const* s, size_t s_len); void FillStringTensorElement(const char* s, size_t index); }; // Represents native memory allocation struct MemoryAllocation { MemoryAllocation(OrtAllocator* allocator, void* p, size_t size); ~MemoryAllocation(); MemoryAllocation(const MemoryAllocation&) = delete; MemoryAllocation& operator=(const MemoryAllocation&) = delete; MemoryAllocation(MemoryAllocation&&) noexcept; MemoryAllocation& operator=(MemoryAllocation&&) noexcept; void* get() { return p_; } size_t size() const { return size_; } private: OrtAllocator* allocator_; void* p_; size_t size_; }; struct AllocatorWithDefaultOptions { AllocatorWithDefaultOptions(); operator OrtAllocator*() { return p_; } operator const OrtAllocator*() const { return p_; } void* Alloc(size_t size); // The return value will own the allocation MemoryAllocation GetAllocation(size_t size); void Free(void* p); const OrtMemoryInfo* GetInfo() const; private: OrtAllocator* p_{}; }; struct MemoryInfo : Base { static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); explicit MemoryInfo(std::nullptr_t) {} explicit MemoryInfo(OrtMemoryInfo* p) : Base{p} {} ///< Used for interop with the C API MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); std::string GetAllocatorName() const; OrtAllocatorType GetAllocatorType() const; int GetDeviceId() const; OrtMemType GetMemoryType() const; bool operator==(const MemoryInfo& o) const; }; struct Allocator : public Base { Allocator(const Session& session, const MemoryInfo&); void* Alloc(size_t size) const; // The return value will own the allocation MemoryAllocation GetAllocation(size_t size); void Free(void* p) const; Unowned GetInfo() const; }; struct IoBinding : public Base { explicit IoBinding(Session& session); void BindInput(const char* name, const Value&); void BindOutput(const char* 
name, const Value&); void BindOutput(const char* name, const MemoryInfo&); std::vector GetOutputNames() const; std::vector GetOutputNames(Allocator&) const; std::vector GetOutputValues() const; std::vector GetOutputValues(Allocator&) const; void ClearBoundInputs(); void ClearBoundOutputs(); void SynchronizeInputs(); void SynchronizeOutputs(); private: std::vector GetOutputNamesHelper(OrtAllocator*) const; std::vector GetOutputValuesHelper(OrtAllocator*) const; }; /*! \struct Ort::ArenaCfg * \brief it is a structure that represents the configuration of an arena based allocator * \details Please see docs/C_API.md for details */ struct ArenaCfg : Base { explicit ArenaCfg(std::nullptr_t) {} ///< Create an empty ArenaCfg object, must be assigned a valid one to be used /** * Wraps OrtApi::CreateArenaCfg * \param max_mem - use 0 to allow ORT to choose the default * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default * See docs/C_API.md for details on what the following parameters mean and how to choose these values */ ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk); }; // // Custom OPs (only needed to implement custom OPs) // struct CustomOpApi { CustomOpApi(const OrtApi& api) : api_(api) {} template // T is only implemented for std::vector, std::vector, float, int64_t, and string T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name); OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value); size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info); ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info); size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info); void GetDimensions(_In_ 
const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length); void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count); template T* GetTensorMutableData(_Inout_ OrtValue* value); template const T* GetTensorData(_Inout_ const OrtValue* value); const OrtMemoryInfo* GetTensorMemoryInfo(_In_ const OrtValue* value); std::vector GetTensorShape(const OrtTensorTypeAndShapeInfo* info); void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input); size_t KernelContext_GetInputCount(const OrtKernelContext* context); const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index); size_t KernelContext_GetOutputCount(const OrtKernelContext* context); OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count); void* KernelContext_GetGPUComputeStream(const OrtKernelContext* context); void ThrowOnError(OrtStatus* result); private: const OrtApi& api_; }; template struct CustomOpBase : OrtCustomOp { CustomOpBase() { OrtCustomOp::version = ORT_API_VERSION; OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast(this_)->CreateKernel(*api, info); }; OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast(this_)->GetName(); }; OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast(this_)->GetExecutionProviderType(); }; OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast(this_)->GetInputTypeCount(); }; OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputType(index); }; OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast(this_)->GetOutputTypeCount(); }; OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputType(index); 
}; OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast(op_kernel)->Compute(context); }; OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast(op_kernel); }; OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputCharacteristic(index); }; OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputCharacteristic(index); }; } // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider const char* GetExecutionProviderType() const { return nullptr; } // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below // (inputs and outputs are required by default) OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; } OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; } }; } // namespace Ort #include "onnxruntime_cxx_inline.h"