You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1193 lines
43 KiB
1193 lines
43 KiB
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
// Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead.
|
|
// If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead.
|
|
//
|
|
// These are the inline implementations of the C++ header APIs. They live in this separate file so as not to clutter
// the main C++ header with implementation details.
|
|
|
|
namespace Ort {
|
|
|
|
inline void ThrowOnError(const OrtApi& ort, OrtStatus* status) {
|
|
if (status) {
|
|
std::string error_message = ort.GetErrorMessage(status);
|
|
OrtErrorCode error_code = ort.GetErrorCode(status);
|
|
ort.ReleaseStatus(status);
|
|
ORT_CXX_API_THROW(std::move(error_message), error_code);
|
|
}
|
|
}
|
|
|
|
inline void ThrowOnError(OrtStatus* status) {
|
|
ThrowOnError(GetApi(), status);
|
|
}
|
|
|
|
// This template maps a C++ type to its corresponding ONNXTensorElementDataType.
|
|
template <typename T>
|
|
struct TypeToTensorType;
|
|
template <>
|
|
struct TypeToTensorType<float> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
|
template <>
|
|
struct TypeToTensorType<Float16_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
|
|
template <>
|
|
struct TypeToTensorType<BFloat16_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; };
|
|
template <>
|
|
struct TypeToTensorType<double> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; };
|
|
template <>
|
|
struct TypeToTensorType<int8_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; };
|
|
template <>
|
|
struct TypeToTensorType<int16_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; };
|
|
template <>
|
|
struct TypeToTensorType<int32_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; };
|
|
template <>
|
|
struct TypeToTensorType<int64_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; };
|
|
template <>
|
|
struct TypeToTensorType<uint8_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; };
|
|
template <>
|
|
struct TypeToTensorType<uint16_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; };
|
|
template <>
|
|
struct TypeToTensorType<uint32_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; };
|
|
template <>
|
|
struct TypeToTensorType<uint64_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; };
|
|
template <>
|
|
struct TypeToTensorType<bool> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; };
|
|
|
|
/// Takes ownership of `p` (of `size` bytes), which must have been obtained from `allocator`.
inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
    : allocator_(allocator), p_(p), size_(size) {
}

/// Returns the owned block (if any) to its allocator.
/// A destructor must not throw, so a failure is swallowed. The OrtStatus describing it
/// is still released, though; discarding it unreleased would leak the status object.
inline MemoryAllocation::~MemoryAllocation() {
  if (p_ != nullptr) {
    OrtStatus* status = GetApi().AllocatorFree(allocator_, p_);
    if (status != nullptr) {
      GetApi().ReleaseStatus(status);
    }
  }
}

/// Move constructor: starts empty, then steals `o`'s allocation via move assignment.
inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) {
  *this = std::move(o);
}

/// Move assignment: releases the currently owned block (if any) and takes over `o`'s.
/// `o` is left empty.
inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) noexcept {
  // Without this guard, a self-move would park our own allocation in `previous`
  // below, free it, and leave *this empty.
  if (this == &o) {
    return *this;
  }

  // Hand our current allocation to a temporary whose destructor frees it when
  // this function returns.
  MemoryAllocation previous(allocator_, p_, size_);

  allocator_ = o.allocator_;
  p_ = o.p_;
  size_ = o.size_;

  o.allocator_ = nullptr;
  o.p_ = nullptr;
  o.size_ = 0;

  return *this;
}
|
|
|
|
inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() {
|
|
ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&p_));
|
|
}
|
|
|
|
inline void* AllocatorWithDefaultOptions::Alloc(size_t size) {
|
|
void* out;
|
|
ThrowOnError(GetApi().AllocatorAlloc(p_, size, &out));
|
|
return out;
|
|
}
|
|
|
|
inline MemoryAllocation Ort::AllocatorWithDefaultOptions::GetAllocation(size_t size) {
|
|
void* out;
|
|
ThrowOnError(GetApi().AllocatorAlloc(p_, size, &out));
|
|
MemoryAllocation result(p_, out, size);
|
|
return result;
|
|
}
|
|
|
|
inline void AllocatorWithDefaultOptions::Free(void* p) {
|
|
ThrowOnError(GetApi().AllocatorFree(p_, p));
|
|
}
|
|
|
|
inline const OrtMemoryInfo* AllocatorWithDefaultOptions::GetInfo() const {
|
|
const OrtMemoryInfo* out;
|
|
ThrowOnError(GetApi().AllocatorGetInfo(p_, &out));
|
|
return out;
|
|
}
|
|
|
|
inline std::string MemoryInfo::GetAllocatorName() const {
|
|
const char* name = nullptr;
|
|
ThrowOnError(GetApi().MemoryInfoGetName(*this, &name));
|
|
return std::string(name);
|
|
}
|
|
|
|
inline OrtAllocatorType MemoryInfo::GetAllocatorType() const {
|
|
OrtAllocatorType type;
|
|
ThrowOnError(GetApi().MemoryInfoGetType(*this, &type));
|
|
return type;
|
|
}
|
|
|
|
inline int MemoryInfo::GetDeviceId() const {
|
|
int id = 0;
|
|
ThrowOnError(GetApi().MemoryInfoGetId(*this, &id));
|
|
return id;
|
|
}
|
|
|
|
inline OrtMemType MemoryInfo::GetMemoryType() const {
|
|
OrtMemType type;
|
|
ThrowOnError(GetApi().MemoryInfoGetMemType(*this, &type));
|
|
return type;
|
|
}
|
|
|
|
inline bool MemoryInfo::operator==(const MemoryInfo& o) const {
|
|
int comp_result = 0;
|
|
ThrowOnError(Ort::GetApi().CompareMemoryInfo(*this, o, &comp_result));
|
|
return comp_result == 0;
|
|
}
|
|
|
|
inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) {
|
|
OrtMemoryInfo* p;
|
|
ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p));
|
|
return MemoryInfo(p);
|
|
}
|
|
|
|
inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
|
|
ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &p_));
|
|
}
|
|
|
|
inline Allocator::Allocator(const Session& sess, const MemoryInfo& mem_info) {
|
|
ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &p_));
|
|
}
|
|
|
|
inline void* Allocator::Alloc(size_t size) const {
|
|
void* out = nullptr;
|
|
ThrowOnError(GetApi().AllocatorAlloc(p_, size, &out));
|
|
return out;
|
|
}
|
|
|
|
inline MemoryAllocation Ort::Allocator::GetAllocation(size_t size) {
|
|
void* out = nullptr;
|
|
ThrowOnError(GetApi().AllocatorAlloc(p_, size, &out));
|
|
MemoryAllocation result(p_, out, size);
|
|
return result;
|
|
}
|
|
|
|
inline void Allocator::Free(void* p) const {
|
|
ThrowOnError(GetApi().AllocatorFree(p_, p));
|
|
}
|
|
|
|
inline Unowned<const MemoryInfo> Allocator::GetInfo() const {
|
|
const OrtMemoryInfo* out = nullptr;
|
|
ThrowOnError(GetApi().AllocatorGetInfo(p_, &out));
|
|
return Unowned<const MemoryInfo>(const_cast<OrtMemoryInfo*>(out));
|
|
}
|
|
|
|
inline IoBinding::IoBinding(Session& session) {
|
|
ThrowOnError(GetApi().CreateIoBinding(session, &p_));
|
|
}
|
|
|
|
inline void IoBinding::BindInput(const char* name, const Value& value) {
|
|
ThrowOnError(GetApi().BindInput(p_, name, value));
|
|
}
|
|
|
|
inline void IoBinding::BindOutput(const char* name, const Value& value) {
|
|
ThrowOnError(GetApi().BindOutput(p_, name, value));
|
|
}
|
|
|
|
inline void IoBinding::BindOutput(const char* name, const MemoryInfo& mem_info) {
|
|
ThrowOnError(GetApi().BindOutputToDevice(p_, name, mem_info));
|
|
}
|
|
|
|
inline std::vector<std::string> IoBinding::GetOutputNamesHelper(OrtAllocator* allocator) const {
|
|
std::vector<std::string> result;
|
|
auto free_fn = [allocator](void* p) { if (p) allocator->Free(allocator, p); };
|
|
using Ptr = std::unique_ptr<void, decltype(free_fn)>;
|
|
|
|
char* buffer = nullptr;
|
|
size_t* lengths = nullptr;
|
|
size_t count = 0;
|
|
ThrowOnError(GetApi().GetBoundOutputNames(p_, allocator, &buffer, &lengths, &count));
|
|
|
|
if (count == 0) {
|
|
return result;
|
|
}
|
|
|
|
Ptr buffer_g(buffer, free_fn);
|
|
Ptr lengths_g(lengths, free_fn);
|
|
|
|
result.reserve(count);
|
|
for (size_t i = 0; i < count; ++i) {
|
|
auto sz = *lengths;
|
|
result.emplace_back(buffer, sz);
|
|
buffer += sz;
|
|
++lengths;
|
|
}
|
|
return result;
|
|
}
|
|
|
|
inline std::vector<std::string> IoBinding::GetOutputNames() const {
|
|
AllocatorWithDefaultOptions allocator;
|
|
return GetOutputNamesHelper(allocator);
|
|
}
|
|
|
|
inline std::vector<std::string> IoBinding::GetOutputNames(Allocator& allocator) const {
|
|
return GetOutputNamesHelper(allocator);
|
|
}
|
|
|
|
inline std::vector<Value> Ort::IoBinding::GetOutputValuesHelper(OrtAllocator* allocator) const {
|
|
std::vector<Value> result;
|
|
size_t owned = 0;
|
|
size_t output_count = 0;
|
|
// Lambda to release the buffer when no longer needed and
|
|
// make sure that we destroy all instances on exception
|
|
auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) {
|
|
if (buffer) {
|
|
while (owned < output_count) {
|
|
auto* p = buffer + owned++;
|
|
GetApi().ReleaseValue(*p);
|
|
}
|
|
allocator->Free(allocator, buffer);
|
|
}
|
|
};
|
|
using Ptr = std::unique_ptr<OrtValue*, decltype(free_fn)>;
|
|
|
|
OrtValue** output_buffer = nullptr;
|
|
ThrowOnError(GetApi().GetBoundOutputValues(p_, allocator, &output_buffer, &output_count));
|
|
if (output_count == 0) {
|
|
return result;
|
|
}
|
|
|
|
Ptr buffer_g(output_buffer, free_fn);
|
|
|
|
result.reserve(output_count);
|
|
for (size_t i = 0; i < output_count; ++i) {
|
|
result.emplace_back(output_buffer[i]);
|
|
++owned;
|
|
}
|
|
return result;
|
|
}
|
|
|
|
inline std::vector<Value> Ort::IoBinding::GetOutputValues(Allocator& allocator) const {
|
|
return GetOutputValuesHelper(allocator);
|
|
}
|
|
|
|
inline std::vector<Value> Ort::IoBinding::GetOutputValues() const {
|
|
AllocatorWithDefaultOptions allocator;
|
|
return GetOutputValuesHelper(allocator);
|
|
}
|
|
|
|
inline void IoBinding::ClearBoundInputs() {
|
|
GetApi().ClearBoundInputs(p_);
|
|
}
|
|
|
|
inline void IoBinding::ClearBoundOutputs() {
|
|
GetApi().ClearBoundOutputs(p_);
|
|
}
|
|
|
|
inline void IoBinding::SynchronizeInputs() {
|
|
ThrowOnError(GetApi().SynchronizeBoundInputs(p_));
|
|
}
|
|
|
|
inline void IoBinding::SynchronizeOutputs() {
|
|
ThrowOnError(GetApi().SynchronizeBoundOutputs(p_));
|
|
}
|
|
|
|
/// Creates an arena configuration by forwarding every argument, in order, to OrtApi::CreateArenaCfg.
inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
  const OrtApi& api = GetApi();
  ThrowOnError(api.CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes,
                                  max_dead_bytes_per_chunk, &p_));
}
|
|
|
|
inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
|
|
ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
|
|
if (strcmp(logid, "onnxruntime-node") == 0) {
|
|
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
|
|
} else {
|
|
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
|
}
|
|
}
|
|
|
|
inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
|
|
ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_));
|
|
if (strcmp(logid, "onnxruntime-node") == 0) {
|
|
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
|
|
} else {
|
|
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
|
}
|
|
}
|
|
|
|
inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) {
|
|
ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_));
|
|
if (strcmp(logid, "onnxruntime-node") == 0) {
|
|
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
|
|
} else {
|
|
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
|
}
|
|
}
|
|
|
|
inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
|
|
OrtLoggingLevel logging_level, _In_ const char* logid) {
|
|
ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_));
|
|
if (strcmp(logid, "onnxruntime-node") == 0) {
|
|
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
|
|
} else {
|
|
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
|
}
|
|
}
|
|
|
|
/// Calls OrtApi::EnableTelemetryEvents; returns *this for chaining.
inline Env& Env::EnableTelemetryEvents() {
  const OrtApi& api = GetApi();
  ThrowOnError(api.EnableTelemetryEvents(p_));
  return *this;
}

/// Calls OrtApi::DisableTelemetryEvents; returns *this for chaining.
inline Env& Env::DisableTelemetryEvents() {
  const OrtApi& api = GetApi();
  ThrowOnError(api.DisableTelemetryEvents(p_));
  return *this;
}

/// Creates an allocator for `mem_info` (optionally arena-configured) and registers it with
/// this environment via OrtApi::CreateAndRegisterAllocator; returns *this for chaining.
inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) {
  const OrtApi& api = GetApi();
  ThrowOnError(api.CreateAndRegisterAllocator(p_, mem_info, arena_cfg));
  return *this;
}
|
|
|
|
inline CustomOpDomain::CustomOpDomain(const char* domain) {
|
|
ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
|
|
}
|
|
|
|
inline void CustomOpDomain::Add(OrtCustomOp* op) {
|
|
ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
|
|
}
|
|
|
|
inline RunOptions::RunOptions() {
|
|
ThrowOnError(GetApi().CreateRunOptions(&p_));
|
|
}
|
|
|
|
inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) {
|
|
ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level));
|
|
return *this;
|
|
}
|
|
|
|
inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) {
|
|
ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level));
|
|
return *this;
|
|
}
|
|
|
|
inline int RunOptions::GetRunLogVerbosityLevel() const {
|
|
int out;
|
|
ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out));
|
|
return out;
|
|
}
|
|
|
|
inline int RunOptions::GetRunLogSeverityLevel() const {
|
|
int out;
|
|
ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out));
|
|
return out;
|
|
}
|
|
|
|
inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
|
|
ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
|
|
return *this;
|
|
}
|
|
|
|
inline const char* RunOptions::GetRunTag() const {
|
|
const char* out;
|
|
ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out));
|
|
return out;
|
|
}
|
|
|
|
inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) {
|
|
ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value));
|
|
return *this;
|
|
}
|
|
|
|
inline RunOptions& RunOptions::SetTerminate() {
|
|
ThrowOnError(GetApi().RunOptionsSetTerminate(p_));
|
|
return *this;
|
|
}
|
|
|
|
inline RunOptions& RunOptions::UnsetTerminate() {
|
|
ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_));
|
|
return *this;
|
|
}
|
|
|
|
// Every setter below forwards its argument(s) to the named OrtApi function, reports
// failures via ThrowOnError, and returns *this so calls can be chained.

/// Creates session options via OrtApi::CreateSessionOptions.
inline SessionOptions::SessionOptions() {
  ThrowOnError(GetApi().CreateSessionOptions(&p_));
}

/// Returns an independent copy produced by OrtApi::CloneSessionOptions.
inline SessionOptions SessionOptions::Clone() const {
  OrtSessionOptions* cloned = nullptr;
  ThrowOnError(GetApi().CloneSessionOptions(p_, &cloned));
  return SessionOptions{cloned};
}

// OrtApi::SetIntraOpNumThreads
inline SessionOptions& SessionOptions::SetIntraOpNumThreads(int intra_op_num_threads) {
  ThrowOnError(GetApi().SetIntraOpNumThreads(p_, intra_op_num_threads));
  return *this;
}

// OrtApi::SetInterOpNumThreads
inline SessionOptions& SessionOptions::SetInterOpNumThreads(int inter_op_num_threads) {
  ThrowOnError(GetApi().SetInterOpNumThreads(p_, inter_op_num_threads));
  return *this;
}

// OrtApi::SetSessionGraphOptimizationLevel
inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
  ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(p_, graph_optimization_level));
  return *this;
}

// OrtApi::SetOptimizedModelFilePath
inline SessionOptions& SessionOptions::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
  ThrowOnError(GetApi().SetOptimizedModelFilePath(p_, optimized_model_filepath));
  return *this;
}

// OrtApi::EnableProfiling
inline SessionOptions& SessionOptions::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
  ThrowOnError(GetApi().EnableProfiling(p_, profile_file_prefix));
  return *this;
}

// OrtApi::DisableProfiling
inline SessionOptions& SessionOptions::DisableProfiling() {
  ThrowOnError(GetApi().DisableProfiling(p_));
  return *this;
}

// OrtApi::EnableOrtCustomOps
inline SessionOptions& SessionOptions::EnableOrtCustomOps() {
  ThrowOnError(GetApi().EnableOrtCustomOps(p_));
  return *this;
}

// OrtApi::EnableMemPattern
inline SessionOptions& SessionOptions::EnableMemPattern() {
  ThrowOnError(GetApi().EnableMemPattern(p_));
  return *this;
}

// OrtApi::DisableMemPattern
inline SessionOptions& SessionOptions::DisableMemPattern() {
  ThrowOnError(GetApi().DisableMemPattern(p_));
  return *this;
}

// OrtApi::EnableCpuMemArena
inline SessionOptions& SessionOptions::EnableCpuMemArena() {
  ThrowOnError(GetApi().EnableCpuMemArena(p_));
  return *this;
}

// OrtApi::DisableCpuMemArena
inline SessionOptions& SessionOptions::DisableCpuMemArena() {
  ThrowOnError(GetApi().DisableCpuMemArena(p_));
  return *this;
}

// OrtApi::SetSessionExecutionMode
inline SessionOptions& SessionOptions::SetExecutionMode(ExecutionMode execution_mode) {
  ThrowOnError(GetApi().SetSessionExecutionMode(p_, execution_mode));
  return *this;
}

// OrtApi::SetSessionLogId
inline SessionOptions& SessionOptions::SetLogId(const char* logid) {
  ThrowOnError(GetApi().SetSessionLogId(p_, logid));
  return *this;
}

// OrtApi::SetSessionLogSeverityLevel
inline SessionOptions& SessionOptions::SetLogSeverityLevel(int level) {
  ThrowOnError(GetApi().SetSessionLogSeverityLevel(p_, level));
  return *this;
}

// OrtApi::AddCustomOpDomain
inline SessionOptions& SessionOptions::Add(OrtCustomOpDomain* custom_op_domain) {
  ThrowOnError(GetApi().AddCustomOpDomain(p_, custom_op_domain));
  return *this;
}

// OrtApi::AddSessionConfigEntry
inline SessionOptions& SessionOptions::AddConfigEntry(const char* config_key, const char* config_value) {
  ThrowOnError(GetApi().AddSessionConfigEntry(p_, config_key, config_value));
  return *this;
}

// OrtApi::AddInitializer
inline SessionOptions& SessionOptions::AddInitializer(const char* name, const OrtValue* ort_val) {
  ThrowOnError(GetApi().AddInitializer(p_, name, ort_val));
  return *this;
}

// OrtApi::SessionOptionsAppendExecutionProvider_CUDA (options passed by address)
inline SessionOptions& SessionOptions::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) {
  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(p_, &provider_options));
  return *this;
}

// OrtApi::SessionOptionsAppendExecutionProvider_ROCM (options passed by address)
inline SessionOptions& SessionOptions::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) {
  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(p_, &provider_options));
  return *this;
}

// OrtApi::SessionOptionsAppendExecutionProvider_TensorRT (options passed by address)
inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(p_, &provider_options));
  return *this;
}

// OrtApi::SessionOptionsSetCustomCreateThreadFn
inline SessionOptions& SessionOptions::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
  ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
  return *this;
}

// OrtApi::SessionOptionsSetCustomThreadCreationOptions
inline SessionOptions& SessionOptions::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
  ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
  return *this;
}

// OrtApi::SessionOptionsSetCustomJoinThreadFn
inline SessionOptions& SessionOptions::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
  ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
  return *this;
}

// OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO (options passed by address)
inline SessionOptions& SessionOptions::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) {
  ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(p_, &provider_options));
  return *this;
}
|
|
|
|
/// Loads a model from `model_path` (OrtApi::CreateSession).
inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
  const OrtApi& api = GetApi();
  ThrowOnError(api.CreateSession(env, model_path, options, &p_));
}

/// Loads a model from `model_path`, sharing `prepacked_weights_container`
/// (OrtApi::CreateSessionWithPrepackedWeightsContainer).
inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
                        OrtPrepackedWeightsContainer* prepacked_weights_container) {
  const OrtApi& api = GetApi();
  ThrowOnError(api.CreateSessionWithPrepackedWeightsContainer(env, model_path, options,
                                                              prepacked_weights_container, &p_));
}

/// Loads a model from the in-memory buffer [model_data, model_data + model_data_length)
/// (OrtApi::CreateSessionFromArray).
inline Session::Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
  const OrtApi& api = GetApi();
  ThrowOnError(api.CreateSessionFromArray(env, model_data, model_data_length, options, &p_));
}
|
|
|
|
inline std::vector<Value> Session::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
|
const char* const* output_names, size_t output_names_count) {
|
|
std::vector<Ort::Value> output_values;
|
|
for (size_t i = 0; i < output_names_count; i++)
|
|
output_values.emplace_back(nullptr);
|
|
Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_names_count);
|
|
return output_values;
|
|
}
|
|
|
|
inline void Session::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
|
const char* const* output_names, Value* output_values, size_t output_count) {
|
|
static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
|
|
auto ort_input_values = reinterpret_cast<const OrtValue**>(const_cast<Value*>(input_values));
|
|
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
|
|
ThrowOnError(GetApi().Run(p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
|
|
}
|
|
|
|
inline void Session::Run(const RunOptions& run_options, const IoBinding& io_binding) {
|
|
ThrowOnError(GetApi().RunWithBinding(p_, run_options, io_binding));
|
|
}
|
|
|
|
inline size_t Session::GetInputCount() const {
|
|
size_t out;
|
|
ThrowOnError(GetApi().SessionGetInputCount(p_, &out));
|
|
return out;
|
|
}
|
|
|
|
inline size_t Session::GetOutputCount() const {
|
|
size_t out;
|
|
ThrowOnError(GetApi().SessionGetOutputCount(p_, &out));
|
|
return out;
|
|
}
|
|
|
|
inline size_t Session::GetOverridableInitializerCount() const {
|
|
size_t out;
|
|
ThrowOnError(GetApi().SessionGetOverridableInitializerCount(p_, &out));
|
|
return out;
|
|
}
|
|
|
|
inline char* Session::GetInputName(size_t index, OrtAllocator* allocator) const {
|
|
char* out;
|
|
ThrowOnError(GetApi().SessionGetInputName(p_, index, allocator, &out));
|
|
return out;
|
|
}
|
|
|
|
inline char* Session::GetOutputName(size_t index, OrtAllocator* allocator) const {
|
|
char* out;
|
|
ThrowOnError(GetApi().SessionGetOutputName(p_, index, allocator, &out));
|
|
return out;
|
|
}
|
|
|
|
inline char* Session::GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const {
|
|
char* out;
|
|
ThrowOnError(GetApi().SessionGetOverridableInitializerName(p_, index, allocator, &out));
|
|
return out;
|
|
}
|
|
|
|
inline char* Session::EndProfiling(OrtAllocator* allocator) const {
|
|
char* out;
|
|
ThrowOnError(GetApi().SessionEndProfiling(p_, allocator, &out));
|
|
return out;
|
|
}
|
|
|
|
inline uint64_t Session::GetProfilingStartTimeNs() const {
|
|
uint64_t out;
|
|
ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(p_, &out));
|
|
return out;
|
|
}
|
|
|
|
inline ModelMetadata Session::GetModelMetadata() const {
|
|
OrtModelMetadata* out;
|
|
ThrowOnError(GetApi().SessionGetModelMetadata(p_, &out));
|
|
return ModelMetadata{out};
|
|
}
|
|
|
|
inline char* ModelMetadata::GetProducerName(OrtAllocator* allocator) const {
|
|
char* out;
|
|
ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
|
|
return out;
|
|
}
|
|
|
|
inline char* ModelMetadata::GetGraphName(OrtAllocator* allocator) const {
|
|
char* out;
|
|
ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
|
|
return out;
|
|
}
|
|
|
|
inline char* ModelMetadata::GetDomain(OrtAllocator* allocator) const {
|
|
char* out;
|
|
ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
|
|
return out;
|
|
}
|
|
|
|
inline char* ModelMetadata::GetDescription(OrtAllocator* allocator) const {
|
|
char* out;
|
|
ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
|
|
return out;
|
|
}
|
|
|
|
inline char* ModelMetadata::GetGraphDescription(OrtAllocator* allocator) const {
|
|
char* out;
|
|
ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out));
|
|
return out;
|
|
}
|
|
|
|
inline char* ModelMetadata::LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const {
|
|
char* out;
|
|
ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
|
|
return out;
|
|
}
|
|
|
|
inline char** ModelMetadata::GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const {
|
|
char** out;
|
|
ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
|
|
return out;
|
|
}
|
|
|
|
inline int64_t ModelMetadata::GetVersion() const {
|
|
int64_t out;
|
|
ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out));
|
|
return out;
|
|
}
|
|
|
|
/// Returns an owning TypeInfo for input `index` (OrtApi::SessionGetInputTypeInfo).
inline TypeInfo Session::GetInputTypeInfo(size_t index) const {
  OrtTypeInfo* info = nullptr;
  ThrowOnError(GetApi().SessionGetInputTypeInfo(p_, index, &info));
  return TypeInfo{info};
}

/// Returns an owning TypeInfo for output `index` (OrtApi::SessionGetOutputTypeInfo).
inline TypeInfo Session::GetOutputTypeInfo(size_t index) const {
  OrtTypeInfo* info = nullptr;
  ThrowOnError(GetApi().SessionGetOutputTypeInfo(p_, index, &info));
  return TypeInfo{info};
}

/// Returns an owning TypeInfo for overridable initializer `index`
/// (OrtApi::SessionGetOverridableInitializerTypeInfo).
inline TypeInfo Session::GetOverridableInitializerTypeInfo(size_t index) const {
  OrtTypeInfo* info = nullptr;
  ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(p_, index, &info));
  return TypeInfo{info};
}
|
|
|
|
inline ONNXTensorElementDataType TensorTypeAndShapeInfo::GetElementType() const {
|
|
ONNXTensorElementDataType out;
|
|
ThrowOnError(GetApi().GetTensorElementType(p_, &out));
|
|
return out;
|
|
}
|
|
|
|
inline size_t TensorTypeAndShapeInfo::GetElementCount() const {
|
|
size_t out;
|
|
ThrowOnError(GetApi().GetTensorShapeElementCount(p_, &out));
|
|
return static_cast<size_t>(out);
|
|
}
|
|
|
|
inline size_t TensorTypeAndShapeInfo::GetDimensionsCount() const {
|
|
size_t out;
|
|
ThrowOnError(GetApi().GetDimensionsCount(p_, &out));
|
|
return out;
|
|
}
|
|
|
|
inline void TensorTypeAndShapeInfo::GetDimensions(int64_t* values, size_t values_count) const {
|
|
ThrowOnError(GetApi().GetDimensions(p_, values, values_count));
|
|
}
|
|
|
|
inline void TensorTypeAndShapeInfo::GetSymbolicDimensions(const char** values, size_t values_count) const {
|
|
ThrowOnError(GetApi().GetSymbolicDimensions(p_, values, values_count));
|
|
}
|
|
|
|
inline std::vector<int64_t> TensorTypeAndShapeInfo::GetShape() const {
|
|
std::vector<int64_t> out(GetDimensionsCount(), 0);
|
|
GetDimensions(out.data(), out.size());
|
|
return out;
|
|
}
|
|
|
|
inline Unowned<TensorTypeAndShapeInfo> TypeInfo::GetTensorTypeAndShapeInfo() const {
|
|
const OrtTensorTypeAndShapeInfo* out;
|
|
ThrowOnError(GetApi().CastTypeInfoToTensorInfo(p_, &out));
|
|
return Unowned<TensorTypeAndShapeInfo>(const_cast<OrtTensorTypeAndShapeInfo*>(out));
|
|
}
|
|
|
|
inline Unowned<SequenceTypeInfo> TypeInfo::GetSequenceTypeInfo() const {
|
|
const OrtSequenceTypeInfo* out;
|
|
ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(p_, &out));
|
|
return Unowned<SequenceTypeInfo>{const_cast<OrtSequenceTypeInfo*>(out)};
|
|
}
|
|
|
|
inline TypeInfo SequenceTypeInfo::GetSequenceElementType() const {
|
|
OrtTypeInfo* output;
|
|
ThrowOnError(GetApi().GetSequenceElementType(p_, &output));
|
|
return TypeInfo{output};
|
|
}
|
|
|
|
inline Unowned<MapTypeInfo> TypeInfo::GetMapTypeInfo() const {
|
|
const OrtMapTypeInfo* out;
|
|
ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(p_, &out));
|
|
return Unowned<MapTypeInfo>{const_cast<OrtMapTypeInfo*>(out)};
|
|
}
|
|
|
|
inline ONNXTensorElementDataType MapTypeInfo::GetMapKeyType() const {
|
|
ONNXTensorElementDataType out;
|
|
ThrowOnError(GetApi().GetMapKeyType(p_, &out));
|
|
return out;
|
|
}
|
|
|
|
inline TypeInfo MapTypeInfo::GetMapValueType() const {
|
|
OrtTypeInfo* output;
|
|
ThrowOnError(GetApi().GetMapValueType(p_, &output));
|
|
return TypeInfo{output};
|
|
}
|
|
|
|
inline ONNXType TypeInfo::GetONNXType() const {
|
|
ONNXType out;
|
|
ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(p_, &out));
|
|
return out;
|
|
}
|
|
|
|
template <typename T>
|
|
inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
|
|
return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
|
|
}
|
|
|
|
inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
|
|
ONNXTensorElementDataType type) {
|
|
OrtValue* out;
|
|
ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
|
|
return Value{out};
|
|
}
|
|
|
|
#if !defined(DISABLE_SPARSE_TENSORS)
|
|
template <typename T>
|
|
inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
|
|
const Shape& values_shape) {
|
|
return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType<T>::type);
|
|
}
|
|
|
|
inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
|
|
const Shape& values_shape, ONNXTensorElementDataType type) {
|
|
OrtValue* out;
|
|
ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len,
|
|
values_shape.shape, values_shape.shape_len, type, &out));
|
|
return Value{out};
|
|
}
|
|
|
|
inline void Value::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param,
|
|
const int64_t* indices_data, size_t indices_num) {
|
|
ThrowOnError(GetApi().FillSparseTensorCoo(p_, mem_info, values_param.values_shape,
|
|
values_param.values_shape_len, values_param.data.p_data,
|
|
indices_data, indices_num));
|
|
}
|
|
|
|
inline void Value::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
|
|
const OrtSparseValuesParam& values,
|
|
const int64_t* inner_indices_data, size_t inner_indices_num,
|
|
const int64_t* outer_indices_data, size_t outer_indices_num) {
|
|
ThrowOnError(GetApi().FillSparseTensorCsr(p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
|
|
inner_indices_data, inner_indices_num,
|
|
outer_indices_data, outer_indices_num));
|
|
}
|
|
|
|
inline void Value::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
|
|
const OrtSparseValuesParam& values,
|
|
const Shape& indices_shape,
|
|
const int32_t* indices_data) {
|
|
ThrowOnError(GetApi().FillSparseTensorBlockSparse(p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
|
|
indices_shape.shape, indices_shape.shape_len,
|
|
indices_data));
|
|
}
|
|
|
|
inline void Value::UseCooIndices(int64_t* indices_data, size_t indices_num) {
|
|
ThrowOnError(GetApi().UseCooIndices(p_, indices_data, indices_num));
|
|
}
|
|
|
|
inline void Value::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) {
|
|
ThrowOnError(GetApi().UseCsrIndices(p_, inner_data, inner_num, outer_data, outer_num));
|
|
}
|
|
|
|
inline void Value::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) {
|
|
ThrowOnError(GetApi().UseBlockSparseIndices(p_, indices_shape.shape, indices_shape.shape_len, indices_data));
|
|
}
|
|
|
|
// Returns the sparse storage format reported by OrtApi::GetSparseTensorFormat.
inline OrtSparseFormat Value::GetSparseFormat() const {
  OrtSparseFormat sparse_format{};
  OrtStatus* status = GetApi().GetSparseTensorFormat(p_, &sparse_format);
  ThrowOnError(status);
  return sparse_format;
}
|
|
|
|
// Returns type/shape info of this sparse tensor's values, wrapping the handle
// produced by OrtApi::GetSparseTensorValuesTypeAndShape.
inline TensorTypeAndShapeInfo Value::GetSparseTensorValuesTypeAndShapeInfo() const {
  OrtTensorTypeAndShapeInfo* values_info = nullptr;
  ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(p_, &values_info));
  return TensorTypeAndShapeInfo{values_info};
}
|
|
|
|
// Returns type/shape info for the index buffer selected by `indices_format`,
// wrapping the handle produced by OrtApi::GetSparseTensorIndicesTypeShape.
inline TensorTypeAndShapeInfo Value::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const {
  OrtTensorTypeAndShapeInfo* indices_info = nullptr;
  ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(p_, indices_format, &indices_info));
  return TensorTypeAndShapeInfo{indices_info};
}
|
|
|
|
template <typename T>
|
|
inline const T* Value::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const {
|
|
const void* out;
|
|
ThrowOnError(GetApi().GetSparseTensorIndices(p_, indices_format, &num_indices, &out));
|
|
return reinterpret_cast<const T*>(out);
|
|
}
|
|
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
|
|
|
template <typename T>
|
|
inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
|
|
return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
|
|
}
|
|
|
|
// Creates a tensor of the given shape and element type using `allocator`, via
// OrtApi::CreateTensorAsOrtValue.
inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
  OrtValue* tensor = nullptr;
  OrtStatus* status = GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &tensor);
  ThrowOnError(status);
  return Value{tensor};
}
|
|
|
|
#if !defined(DISABLE_SPARSE_TENSORS)
|
|
template <typename T>
|
|
inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) {
|
|
return CreateSparseTensor(allocator, dense_shape, TypeToTensorType<T>::type);
|
|
}
|
|
|
|
// Creates a sparse tensor with the given dense shape and element type using
// `allocator`, via OrtApi::CreateSparseTensorAsOrtValue.
inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape,
                                       ONNXTensorElementDataType type) {
  const OrtApi& api = GetApi();
  OrtValue* sparse = nullptr;
  ThrowOnError(api.CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len,
                                                type, &sparse));
  return Value{sparse};
}
|
|
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
|
|
|
// Creates an ONNX map Value from a keys Value and a values Value, passed to
// OrtApi::CreateValue as a two-element array in (keys, values) order.
inline Value Value::CreateMap(Value& keys, Value& values) {
  OrtValue* key_value_pair[2] = {keys, values};
  OrtValue* map_value = nullptr;
  OrtStatus* status = GetApi().CreateValue(key_value_pair, 2, ONNX_TYPE_MAP, &map_value);
  ThrowOnError(status);
  return Value{map_value};
}
|
|
|
|
// Creates an ONNX sequence Value from `values`, passing their raw OrtValue
// handles (in order) to OrtApi::CreateValue.
inline Value Value::CreateSequence(std::vector<Value>& values) {
  std::vector<OrtValue*> handles;
  handles.reserve(values.size());
  for (Value& element : values) {
    handles.push_back(static_cast<OrtValue*>(element));
  }

  OrtValue* sequence = nullptr;
  ThrowOnError(GetApi().CreateValue(handles.data(), handles.size(), ONNX_TYPE_SEQUENCE, &sequence));
  return Value{sequence};
}
|
|
|
|
template <typename T>
|
|
inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) {
|
|
OrtValue* out;
|
|
ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
|
|
return Value{out};
|
|
}
|
|
|
|
template <typename T>
|
|
inline void Value::GetOpaqueData(const char* domain, const char* type_name, T& out) const {
|
|
ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, p_, &out, sizeof(T)));
|
|
}
|
|
|
|
// True when OrtApi::IsTensor reports a non-zero flag for this Value.
inline bool Value::IsTensor() const {
  int is_tensor = 0;
  ThrowOnError(GetApi().IsTensor(p_, &is_tensor));
  return static_cast<bool>(is_tensor);
}
|
|
|
|
// True when OrtApi::HasValue reports a non-zero flag for this Value.
inline bool Value::HasValue() const {
  int has_value = 0;
  ThrowOnError(GetApi().HasValue(p_, &has_value));
  return static_cast<bool>(has_value);
}
|
|
|
|
#if !defined(DISABLE_SPARSE_TENSORS)
|
|
// True when OrtApi::IsSparseTensor reports a non-zero flag for this Value.
inline bool Value::IsSparseTensor() const {
  int is_sparse = 0;
  ThrowOnError(GetApi().IsSparseTensor(p_, &is_sparse));
  return static_cast<bool>(is_sparse);
}
|
|
#endif
|
|
|
|
inline size_t Value::GetCount() const {
|
|
size_t out;
|
|
ThrowOnError(GetApi().GetValueCount(p_, &out));
|
|
return out;
|
|
}
|
|
|
|
// Fetches the sub-value at `index` through OrtApi::GetValue, using `allocator`,
// and wraps the resulting handle in a Value.
inline Value Value::GetValue(int index, OrtAllocator* allocator) const {
  OrtValue* element = nullptr;
  OrtStatus* status = GetApi().GetValue(p_, index, allocator, &element);
  ThrowOnError(status);
  return Value{element};
}
|
|
|
|
// Returns the length reported by OrtApi::GetStringTensorDataLength for this string tensor.
inline size_t Value::GetStringTensorDataLength() const {
  size_t total_length = 0;
  ThrowOnError(GetApi().GetStringTensorDataLength(p_, &total_length));
  return total_length;
}
|
|
|
|
// Returns the length of string element `element_index`, as reported by
// OrtApi::GetStringTensorElementLength.
inline size_t Value::GetStringTensorElementLength(size_t element_index) const {
  size_t element_length = 0;
  ThrowOnError(GetApi().GetStringTensorElementLength(p_, element_index, &element_length));
  return element_length;
}
|
|
|
|
inline void Value::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
|
|
ThrowOnError(GetApi().GetStringTensorContent(p_, buffer, buffer_length, offsets, offsets_count));
|
|
}
|
|
|
|
inline void Value::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const {
|
|
ThrowOnError(GetApi().GetStringTensorElement(p_, buffer_length, element_index, buffer));
|
|
}
|
|
|
|
inline void Value::FillStringTensor(const char* const* s, size_t s_len) {
|
|
ThrowOnError(GetApi().FillStringTensor(p_, s, s_len));
|
|
}
|
|
|
|
inline void Value::FillStringTensorElement(const char* s, size_t index) {
|
|
ThrowOnError(GetApi().FillStringTensorElement(p_, s, index));
|
|
}
|
|
|
|
template <typename T>
|
|
T* Value::GetTensorMutableData() {
|
|
T* out;
|
|
ThrowOnError(GetApi().GetTensorMutableData(p_, (void**)&out));
|
|
return out;
|
|
}
|
|
|
|
template <typename T>
|
|
const T* Value::GetTensorData() const {
|
|
T* out;
|
|
ThrowOnError(GetApi().GetTensorMutableData(p_, (void**)&out));
|
|
return out;
|
|
}
|
|
|
|
#if !defined(DISABLE_SPARSE_TENSORS)
|
|
template <typename T>
|
|
inline const T* Value::GetSparseTensorValues() const {
|
|
const void* out;
|
|
ThrowOnError(GetApi().GetSparseTensorValues(p_, &out));
|
|
return reinterpret_cast<const T*>(out);
|
|
}
|
|
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
|
|
|
template <typename T>
|
|
inline T& Value::At(const std::vector<int64_t>& location) {
|
|
static_assert(!std::is_same<T, std::string>::value, "this api does not support std::string");
|
|
T* out;
|
|
ThrowOnError(GetApi().TensorAt(p_, location.data(), location.size(), (void**)&out));
|
|
return *out;
|
|
}
|
|
|
|
// Returns the type information of this Value, wrapping the handle produced by
// OrtApi::GetTypeInfo.
inline TypeInfo Value::GetTypeInfo() const {
  OrtTypeInfo* type_info = nullptr;
  OrtStatus* status = GetApi().GetTypeInfo(p_, &type_info);
  ThrowOnError(status);
  return TypeInfo{type_info};
}
|
|
|
|
// Returns element type and shape information of this tensor, wrapping the handle
// produced by OrtApi::GetTensorTypeAndShape.
inline TensorTypeAndShapeInfo Value::GetTensorTypeAndShapeInfo() const {
  OrtTensorTypeAndShapeInfo* shape_info = nullptr;
  OrtStatus* status = GetApi().GetTensorTypeAndShape(p_, &shape_info);
  ThrowOnError(status);
  return TensorTypeAndShapeInfo{shape_info};
}
|
|
|
|
//
|
|
// Custom OP API Inlines
|
|
//
|
|
// Converts a non-null OrtStatus into an error using this object's OrtApi table.
// The namespace-level helper must be named explicitly: an unqualified call here
// would find this one-argument member instead of the two-argument overload.
inline void CustomOpApi::ThrowOnError(OrtStatus* status) {
  ::Ort::ThrowOnError(api_, status);
}
|
|
|
|
template <>
|
|
inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
|
float out;
|
|
ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
|
|
return out;
|
|
}
|
|
|
|
template <>
|
|
inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
|
int64_t out;
|
|
ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
|
|
return out;
|
|
}
|
|
|
|
template <>
|
|
inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
|
size_t size = 0;
|
|
std::string out;
|
|
|
|
// Feed nullptr for the data buffer to query the true size of the string attribute
|
|
OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
|
|
|
|
if (status == nullptr) {
|
|
out.resize(size);
|
|
ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
|
|
out.resize(size - 1); // remove the terminating character '\0'
|
|
} else {
|
|
ThrowOnError(status);
|
|
}
|
|
return out;
|
|
}
|
|
|
|
template <>
|
|
inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
|
size_t size = 0;
|
|
std::vector<float> out;
|
|
|
|
// Feed nullptr for the data buffer to query the true size of the attribute
|
|
OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
|
|
|
|
if (status == nullptr) {
|
|
out.resize(size);
|
|
ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
|
|
} else {
|
|
ThrowOnError(status);
|
|
}
|
|
return out;
|
|
}
|
|
|
|
template <>
|
|
inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
|
size_t size = 0;
|
|
std::vector<int64_t> out;
|
|
|
|
// Feed nullptr for the data buffer to query the true size of the attribute
|
|
OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
|
|
|
|
if (status == nullptr) {
|
|
out.resize(size);
|
|
ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
|
|
} else {
|
|
ThrowOnError(status);
|
|
}
|
|
return out;
|
|
}
|
|
// Returns the raw type/shape info handle for `value` from OrtApi::GetTensorTypeAndShape.
// The handle is returned unwrapped; ReleaseTensorTypeAndShapeInfo is the matching release call.
inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) {
  OrtTensorTypeAndShapeInfo* shape_info = nullptr;
  OrtStatus* status = api_.GetTensorTypeAndShape(value, &shape_info);
  ThrowOnError(status);
  return shape_info;
}
|
|
|
|
// Returns the element count reported by OrtApi::GetTensorShapeElementCount for `info`.
inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
  size_t element_count = 0;
  ThrowOnError(api_.GetTensorShapeElementCount(info, &element_count));
  return element_count;
}
|
|
|
|
// Returns the tensor element type recorded in `info`, via OrtApi::GetTensorElementType.
inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) {
  ONNXTensorElementDataType element_type{};
  ThrowOnError(api_.GetTensorElementType(info, &element_type));
  return element_type;
}
|
|
|
|
// Returns the number of dimensions recorded in `info`, via OrtApi::GetDimensionsCount.
inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
  size_t rank = 0;
  ThrowOnError(api_.GetDimensionsCount(info, &rank));
  return rank;
}
|
|
|
|
// Copies up to `dim_values_length` dimensions from `info` into `dim_values`,
// via OrtApi::GetDimensions.
inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) {
  OrtStatus* status = api_.GetDimensions(info, dim_values, dim_values_length);
  ThrowOnError(status);
}
|
|
|
|
// Writes `dim_count` dimensions from `dim_values` into `info`, via OrtApi::SetDimensions.
inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) {
  OrtStatus* status = api_.SetDimensions(info, dim_values, dim_count);
  ThrowOnError(status);
}
|
|
|
|
template <typename T>
|
|
inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) {
|
|
T* data;
|
|
ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
|
|
return data;
|
|
}
|
|
|
|
// Returns the memory-info handle reported by OrtApi::GetTensorMemoryInfo for `value`.
inline const OrtMemoryInfo* CustomOpApi::GetTensorMemoryInfo(_In_ const OrtValue* value) {
  const OrtMemoryInfo* memory_info = nullptr;
  OrtStatus* status = api_.GetTensorMemoryInfo(value, &memory_info);
  ThrowOnError(status);
  return memory_info;
}
|
|
|
|
template <typename T>
|
|
inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) {
|
|
return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
|
|
}
|
|
|
|
inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) {
|
|
std::vector<int64_t> output(GetDimensionsCount(info));
|
|
GetDimensions(info, output.data(), output.size());
|
|
return output;
|
|
}
|
|
|
|
// Releases a type/shape info handle (e.g. one returned by GetTensorTypeAndShape)
// by delegating straight to the C API release function.
inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) {
  this->api_.ReleaseTensorTypeAndShapeInfo(input);
}
|
|
|
|
// Returns the input count reported by OrtApi::KernelContext_GetInputCount.
inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) {
  size_t input_count = 0;
  ThrowOnError(api_.KernelContext_GetInputCount(context, &input_count));
  return input_count;
}
|
|
|
|
// Returns the input at position `index` of the kernel context, via
// OrtApi::KernelContext_GetInput.
inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) {
  const OrtValue* input = nullptr;
  ThrowOnError(api_.KernelContext_GetInput(context, index, &input));
  return input;
}
|
|
|
|
// Returns the output count reported by OrtApi::KernelContext_GetOutputCount.
inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) {
  size_t output_count = 0;
  ThrowOnError(api_.KernelContext_GetOutputCount(context, &output_count));
  return output_count;
}
|
|
|
|
// Returns the output at position `index`, passing the requested shape
// (`dim_values`, `dim_count`) to OrtApi::KernelContext_GetOutput.
inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
                                                      _In_ const int64_t* dim_values, size_t dim_count) {
  OrtValue* output = nullptr;
  OrtStatus* status = api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &output);
  ThrowOnError(status);
  return output;
}
|
|
|
|
inline void* CustomOpApi::KernelContext_GetGPUComputeStream(const OrtKernelContext* context) {
|
|
void* out;
|
|
ThrowOnError(api_.KernelContext_GetGPUComputeStream(context, &out));
|
|
return out;
|
|
}
|
|
|
|
// Calls OrtApi::DisablePerSessionThreads on these options and returns *this for chaining.
inline SessionOptions& SessionOptions::DisablePerSessionThreads() {
  OrtStatus* status = GetApi().DisablePerSessionThreads(p_);
  ThrowOnError(status);
  return *this;
}
|
|
|
|
inline std::vector<std::string> GetAvailableProviders() {
|
|
int len;
|
|
char** providers;
|
|
const OrtApi& api = GetApi();
|
|
ThrowOnError(api.GetAvailableProviders(&providers, &len));
|
|
std::vector<std::string> available_providers(providers, providers + len);
|
|
ThrowOnError(api.ReleaseAvailableProviders(providers, len));
|
|
return available_providers;
|
|
}
|
|
|
|
// NOTE(review): this declares a namespace-scope free function Ort::AddInitializer,
// and no definition of it is visible in this part of the file. It presumably
// duplicates a SessionOptions member of the same name. Confirm, and remove it if
// nothing refers to it: calling it without a definition fails at link time.
SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val);
|
|
|
|
} // namespace Ort
|
|
|