// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Summary
// This header provides APIs that save custom op authors the trouble of defining schemas:
// the schema is inferred from the function signature, as long as the argument list uses types supported here.
// An input may be:
// 1. A tensor of ONNX data types.
// 2. A span of ONNX data types.
// 3. A scalar of ONNX data types.
// An input may be marked optional by wrapping it in std::optional<...>.
// An output must be a tensor of ONNX data types.
// The header also provides a utility to register a simple custom struct, which may hold resources, as a custom op.
// For concrete examples, please search for the keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
// Note - all APIs in this header are part of the ABI.

#pragma once
#include "onnxruntime_cxx_api.h"
#include <cstring>
#include <optional>
#include <numeric>
#include <functional>
#include <unordered_set>

namespace Ort {
namespace Custom {

class ArgBase {
 public:
  ArgBase(OrtKernelContext* ctx,
          size_t indice,
          bool is_input) : ctx_(ctx), indice_(indice), is_input_(is_input) {}
  virtual ~ArgBase() {}

 protected:
  struct KernelContext ctx_;
  size_t indice_;
  bool is_input_;
};

using ArgPtr = std::unique_ptr<Custom::ArgBase>;
using ArgPtrs = std::vector<ArgPtr>;

class TensorBase : public ArgBase {
 public:
  TensorBase(OrtKernelContext* ctx,
             size_t indice,
             bool is_input) : ArgBase(ctx, indice, is_input) {}

  operator bool() const {
    return shape_.has_value();
  }

  const std::vector<int64_t>& Shape() const {
    if (!shape_.has_value()) {
      ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return shape_.value();
  }

  ONNXTensorElementDataType Type() const {
    return type_;
  }

  int64_t NumberOfElement() const {
    if (shape_.has_value()) {
      return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
    } else {
      return 0;
    }
  }

  std::string Shape2Str() const {
    if (shape_.has_value()) {
      std::string shape_str;
      for (const auto& dim : *shape_) {
        shape_str.append(std::to_string(dim));
        shape_str.append(", ");
      }
      return shape_str;
    } else {
      return "empty";
    }
  }

  bool IsCpuTensor() const {
    return strcmp("Cpu", mem_type_) == 0;
  }

  virtual const void* DataRaw() const = 0;
  virtual size_t SizeInBytes() const = 0;

 protected:
  std::optional<std::vector<int64_t>> shape_;
  ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  const char* mem_type_ = "Cpu";
};

template <typename T>
struct Span {
  const T* data_ = {};
  size_t size_ = {};
  void Assign(const T* data, size_t size) {
    data_ = data;
    size_ = size;
  }
  size_t size() const { return size_; }
  T operator[](size_t indice) const {
    return data_[indice];
  }
  const T* data() const { return data_; }
};

template <typename T>
class Tensor : public TensorBase {
 public:
  using TT = typename std::remove_reference<T>::type;
  Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
    if (is_input_) {
      if (indice >= ctx_.GetInputCount()) {
        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
      }
      const_value_ = ctx_.GetInput(indice);
      auto type_shape_info = const_value_.GetTensorTypeAndShapeInfo();
      shape_ = type_shape_info.GetShape();
    }
  }
  const TT* Data() const {
    return reinterpret_cast<const TT*>(const_value_.GetTensorRawData());
  }
  TT* Allocate(const std::vector<int64_t>& shape) {
    shape_ = shape;
    if (!data_) {
      data_ = ctx_.GetOutput(indice_, shape).template GetTensorMutableData<TT>();
    }
    return data_;
  }
  static TT GetT() { return (TT)0; }
  const Span<T>& AsSpan() {
    if (!shape_.has_value() || shape_->size() != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a span out of Ort::Custom::Tensor", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    span_.Assign(Data(), static_cast<size_t>((*shape_)[0]));
    return span_;
  }
  const T& AsScalar() {
    if (!shape_.has_value() || shape_->size() != 1 || (*shape_)[0] != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a scalar from Ort::Custom::Tensor", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return *Data();
  }
  const void* DataRaw() const override {
    return reinterpret_cast<const void*>(Data());
  }
  size_t SizeInBytes() const override {
    return sizeof(TT) * static_cast<size_t>(NumberOfElement());
  }

 private:
  ConstValue const_value_;  // for input
  TT* data_{};              // for output
  Span<T> span_;
};
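
// A minimal sketch of how Tensor<T> is typically consumed (not part of this
// header): the op name "Square" and its body are hypothetical. Inputs expose
// Data()/Shape(); outputs are backed by the kernel context via Allocate().
//
//   void Square(const Ort::Custom::Tensor<float>& in,
//               Ort::Custom::Tensor<float>& out) {
//     const float* x = in.Data();
//     float* y = out.Allocate(in.Shape());  // allocates the output with the input's shape
//     for (int64_t i = 0; i < in.NumberOfElement(); ++i) {
//       y[i] = x[i] * x[i];
//     }
//   }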

template <>
class Tensor<std::string> : public TensorBase {
 public:
  using strings = std::vector<std::string>;

  Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
    if (is_input_) {
      if (indice >= ctx_.GetInputCount()) {
        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
      }
      auto const_value = ctx_.GetInput(indice);
      auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
      shape_ = type_shape_info.GetShape();
      auto num_chars = const_value.GetStringTensorDataLength();
      // note - there will be a copy ...
      auto num_strings = static_cast<size_t>(NumberOfElement());
      if (num_strings) {
        std::vector<char> chars(num_chars + 1, '\0');
        std::vector<size_t> offsets(num_strings);
        const_value.GetStringTensorContent(static_cast<void*>(chars.data()), num_chars, offsets.data(), offsets.size());
        auto upper_bound = num_strings - 1;
        input_strings_.resize(num_strings);
        // iterate from the last string to the first, so each offset can be turned into
        // a null terminator for the preceding string before that string is read
        for (size_t i = upper_bound;; --i) {
          if (i < upper_bound) {
            chars[offsets[i + 1]] = '\0';
          }
          input_strings_[i] = chars.data() + offsets[i];
          if (0 == i) {
            break;
          }
        }
      }
    }
  }
  const strings& Data() const {
    return input_strings_;
  }
  const void* DataRaw() const override {
    if (input_strings_.size() != 1) {
      ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return reinterpret_cast<const void*>(input_strings_[0].c_str());
  }
  size_t SizeInBytes() const override {
    if (input_strings_.size() != 1) {
      ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return input_strings_[0].size();
  }
  void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
    shape_ = dims;
    std::vector<const char*> raw;
    for (const auto& s : ss) {
      raw.push_back(s.data());
    }
    auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
    // note - there will be a copy ...
    output.FillStringTensor(raw.data(), raw.size());
  }
  const Span<std::string>& AsSpan() {
    ORT_CXX_API_THROW("span for Tensor of string is not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
  }
  const std::string& AsScalar() {
    if (input_strings_.size() != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a scalar string from Ort::Custom::Tensor", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return input_strings_[0];
  }

 private:
  std::vector<std::string> input_strings_;  // for input
};
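
// A sketch of producing a string output (not part of this header): the op name
// "ToLower" is hypothetical. SetStringOutput copies the strings into the output
// tensor allocated on the kernel context.
//
//   void ToLower(const Ort::Custom::Tensor<std::string>& in,
//                Ort::Custom::Tensor<std::string>& out) {
//     std::vector<std::string> results = in.Data();  // Data() returns the decoded strings
//     for (auto& s : results) {
//       for (auto& c : s) c = static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
//     }
//     out.SetStringOutput(results, in.Shape());
//   }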

template <>
class Tensor<std::string_view> : public TensorBase {
 public:
  using strings = std::vector<std::string>;
  using string_views = std::vector<std::string_view>;

  Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
    if (is_input_) {
      if (indice >= ctx_.GetInputCount()) {
        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
      }
      auto const_value = ctx_.GetInput(indice);
      auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
      shape_ = type_shape_info.GetShape();
      auto num_chars = const_value.GetStringTensorDataLength();
      chars_.resize(num_chars + 1, '\0');
      auto num_strings = static_cast<size_t>(NumberOfElement());
      if (num_strings) {
        std::vector<size_t> offsets(num_strings);
        const_value.GetStringTensorContent(static_cast<void*>(chars_.data()), num_chars, offsets.data(), offsets.size());
        offsets.push_back(num_chars);
        for (size_t i = 0; i < num_strings; ++i) {
          input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
        }
      }
    }
  }
  const string_views& Data() const {
    return input_string_views_;
  }
  const void* DataRaw() const override {
    if (input_string_views_.size() != 1) {
      ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return reinterpret_cast<const void*>(input_string_views_[0].data());
  }
  size_t SizeInBytes() const override {
    if (input_string_views_.size() != 1) {
      ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return input_string_views_[0].size();
  }
  void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
    shape_ = dims;
    std::vector<const char*> raw;
    for (const auto& s : ss) {
      raw.push_back(s.data());
    }
    auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
    // note - there will be a copy ...
    output.FillStringTensor(raw.data(), raw.size());
  }
  const Span<std::string_view>& AsSpan() {
    ORT_CXX_API_THROW("span for Tensor of string view is not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
  }
  std::string_view AsScalar() {
    if (input_string_views_.size() != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a scalar string view from Ort::Custom::Tensor", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return input_string_views_[0];
  }

 private:
  std::vector<char> chars_;                           // for input
  std::vector<std::string_view> input_string_views_;  // for input
};

using TensorPtr = std::unique_ptr<Custom::TensorBase>;
using TensorPtrs = std::vector<TensorPtr>;

struct TensorArray : public ArgBase {
  TensorArray(OrtKernelContext* ctx,
              size_t start_indice,
              bool is_input) : ArgBase(ctx, start_indice, is_input) {
    if (is_input) {
      auto input_count = ctx_.GetInputCount();
      for (size_t ith_input = start_indice; ith_input < input_count; ++ith_input) {
        auto const_value = ctx_.GetInput(ith_input);
        auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
        auto type = type_shape_info.GetElementType();
        TensorPtr tensor;
        switch (type) {
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
            tensor = std::make_unique<Custom::Tensor<bool>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
            tensor = std::make_unique<Custom::Tensor<float>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
            tensor = std::make_unique<Custom::Tensor<double>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
            tensor = std::make_unique<Custom::Tensor<uint8_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
            tensor = std::make_unique<Custom::Tensor<int8_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
            tensor = std::make_unique<Custom::Tensor<uint16_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
            tensor = std::make_unique<Custom::Tensor<int16_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
            tensor = std::make_unique<Custom::Tensor<uint32_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
            tensor = std::make_unique<Custom::Tensor<int32_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
            tensor = std::make_unique<Custom::Tensor<uint64_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
            tensor = std::make_unique<Custom::Tensor<int64_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
            tensor = std::make_unique<Custom::Tensor<std::string>>(ctx, ith_input, true);
            break;
          default:
            ORT_CXX_API_THROW("unknown input type", ORT_RUNTIME_EXCEPTION);
            break;
        }
        tensors_.emplace_back(tensor.release());
      }  // for
    }
  }

  template <typename T>
  T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
    // ith_output is the indice of the output relative to the tensor array
    // indice_ + ith_output is the indice relative to the context
    auto tensor = std::make_unique<Tensor<T>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
    auto raw_output = tensor.get()->Allocate(shape);
    tensors_.emplace_back(tensor.release());
    return raw_output;
  }

  Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
    // ith_output is the indice of the output relative to the tensor array
    // indice_ + ith_output is the indice relative to the context
    auto tensor = std::make_unique<Tensor<std::string>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
    Tensor<std::string>& output = *tensor;
    tensors_.emplace_back(tensor.release());
    return output;
  }

  size_t Size() const {
    return tensors_.size();
  }

  const TensorPtr& operator[](size_t ith_input) const {
    // ith_input is the indice of the tensor relative to the tensor array
    return tensors_.at(ith_input);
  }

 private:
  TensorPtrs tensors_;
};

using Variadic = TensorArray;
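
// A sketch of a variadic op using TensorArray (not part of this header): the op
// name "Copy" is hypothetical. Since variadic inputs are not homogeneous, each
// element's Type() should be checked before use.
//
//   void Copy(const Ort::Custom::Variadic& inputs, Ort::Custom::Variadic& outputs) {
//     for (size_t i = 0; i < inputs.Size(); ++i) {
//       const auto& input = inputs[i];
//       if (input->Type() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
//         auto* raw = outputs.AllocateOutput<float>(i, input->Shape());
//         memcpy(raw, input->DataRaw(), input->SizeInBytes());
//       }
//     }
//   }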

/*
Note:
OrtLiteCustomOp inherits from OrtCustomOp to bridge between a custom func/struct and the ORT core.
The lifetime of an OrtLiteCustomOp instance is managed by customer code, not by ORT, so:
1. DO NOT cast an OrtLiteCustomOp to OrtCustomOp and release it, since there is no virtual destructor in the hierarchy.
2. OrtLiteCustomFunc and OrtLiteCustomStruct, the two sub-structs, can be released as OrtLiteCustomOp, since all members
   are kept in OrtLiteCustomOp and the memory is still recycled properly.
Furthermore, OrtCustomOp is a C struct bearing no v-table, so derived structs are by design free of virtual functions
to maintain cast safety.
*/
struct OrtLiteCustomOp : public OrtCustomOp {
  using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
  using OptionalFloatTensor = std::optional<Custom::Tensor<float>>;

  // CreateTuple
  template <size_t ith_input, size_t ith_output, typename... Ts>
  static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
  CreateTuple(OrtKernelContext*, ArgPtrs&, size_t, size_t, const std::string&) {
    return std::make_tuple();
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, OrtKernelContext&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    std::tuple<T> current = std::tuple<OrtKernelContext&>{*context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

#ifdef ORT_CUDA_CTX
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const CudaContext&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    thread_local CudaContext cuda_context;
    cuda_context.Init(*context);
    std::tuple<T> current = std::tuple<const CudaContext&>{cuda_context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }
#endif

#ifdef ORT_ROCM_CTX
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const RocmContext&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    thread_local RocmContext rocm_context;
    rocm_context.Init(*context);
    std::tuple<T> current = std::tuple<const RocmContext&>{rocm_context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }
#endif
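
  // A sketch of a kernel taking a device context (not part of this header): the
  // op name "CudaSquare" is hypothetical, and the CudaContext member used for
  // the launch is an assumption; such arguments are only available when
  // ORT_CUDA_CTX / ORT_ROCM_CTX is defined.
  //
  //   void CudaSquare(const Ort::Custom::CudaContext& cuda_ctx,
  //                   const Ort::Custom::Tensor<float>& in,
  //                   Ort::Custom::Tensor<float>& out) {
  //     float* y = out.Allocate(in.Shape());
  //     // launch a device kernel on cuda_ctx.cuda_stream ...
  //   }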

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const TensorArray*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const TensorArray&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, TensorArray*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, TensorArray&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

#define CREATE_TUPLE_INPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if ("CPUExecutionProvider" != ep) { \
      ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
    } \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if ("CPUExecutionProvider" != ep) { \
      ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
    } \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      if ("CPUExecutionProvider" != ep) { \
        ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
      } \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if ("CPUExecutionProvider" != ep) { \
      ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
    } \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      if ("CPUExecutionProvider" != ep) { \
        ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
      } \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  }

#define CREATE_TUPLE_OUTPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_output < num_output) { \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  }

#define CREATE_TUPLE(data_type) \
  CREATE_TUPLE_INPUT(data_type) \
  CREATE_TUPLE_OUTPUT(data_type)

  CREATE_TUPLE(bool)
  CREATE_TUPLE(float)
  CREATE_TUPLE(Ort::Float16_t)
  CREATE_TUPLE(Ort::BFloat16_t)
  CREATE_TUPLE(double)
  CREATE_TUPLE(int8_t)
  CREATE_TUPLE(int16_t)
  CREATE_TUPLE(int32_t)
  CREATE_TUPLE(int64_t)
  CREATE_TUPLE(uint8_t)
  CREATE_TUPLE(uint16_t)
  CREATE_TUPLE(uint32_t)
  CREATE_TUPLE(uint64_t)
  CREATE_TUPLE(std::string)
  CREATE_TUPLE_INPUT(std::string_view)
  CREATE_TUPLE(Ort::Float8E4M3FN_t)
  CREATE_TUPLE(Ort::Float8E4M3FNUZ_t)
  CREATE_TUPLE(Ort::Float8E5M2_t)
  CREATE_TUPLE(Ort::Float8E5M2FNUZ_t)
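
  // A sketch of an optional input (not part of this header): the op name
  // "AddOrCopy" is hypothetical. A std::optional parameter maps to an optional
  // schema input and is empty when the node omits that input.
  //
  //   void AddOrCopy(const Ort::Custom::Tensor<float>& a,
  //                  std::optional<const Ort::Custom::Tensor<float>*> b,
  //                  Ort::Custom::Tensor<float>& out) {
  //     float* y = out.Allocate(a.Shape());
  //     for (int64_t i = 0; i < a.NumberOfElement(); ++i) {
  //       y[i] = b.has_value() ? a.Data()[i] + (*b)->Data()[i] : a.Data()[i];
  //     }
  //   }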

  // ParseArgs ...
  template <typename... Ts>
  static typename std::enable_if<0 == sizeof...(Ts)>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }

#ifdef ORT_CUDA_CTX
  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const CudaContext&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }
#endif

#ifdef ORT_ROCM_CTX
  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const RocmContext&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }
#endif

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray*>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray*>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

#define PARSE_INPUT_BASE(pack_type, onnx_type) \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const std::optional<pack_type>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<pack_type>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }

#define PARSE_INPUT(data_type, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(data_type, onnx_type)

#define PARSE_OUTPUT(data_type, onnx_type) \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>*>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>&>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }

#define PARSE_ARGS(data_type, onnx_type) \
  PARSE_INPUT(data_type, onnx_type) \
  PARSE_OUTPUT(data_type, onnx_type)

  PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
  PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
  PARSE_ARGS(Ort::Float16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
  PARSE_ARGS(Ort::BFloat16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
  PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
  PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
  PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
  PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
  PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
  PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
  PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
  PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
  PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
  PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
  PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)  // todo - remove string_view output
  PARSE_ARGS(Ort::Float8E4M3FN_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN)
  PARSE_ARGS(Ort::Float8E4M3FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ)
  PARSE_ARGS(Ort::Float8E5M2_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2)
  PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ)

  OrtLiteCustomOp(const char* op_name,
                  const char* execution_provider,
                  ShapeInferFn shape_infer_fn,
                  int start_ver = 1,
                  int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
                                                         execution_provider_(execution_provider),
                                                         shape_infer_fn_(shape_infer_fn),
                                                         start_ver_(start_ver),
                                                         end_ver_(end_ver) {
    OrtCustomOp::version = ORT_API_VERSION;

    OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };

    OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };

    OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) { return OrtMemTypeDefault; };

    OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->input_types_.size();
    };

    OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->input_types_[indice];
    };

    OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->output_types_.size();
    };

    OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->output_types_[indice];
    };

    OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->input_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
    };

    OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->output_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
    };

    OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 1; };
    OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { return 0; };
    OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { return 1; };
    OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { return 0; };

    OrtCustomOp::CreateKernelV2 = {};
    OrtCustomOp::KernelComputeV2 = {};
    OrtCustomOp::KernelCompute = {};
    OrtCustomOp::InferOutputShapeFn = {};

    OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->start_ver_;
    };

    OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->end_ver_;
    };

    OrtCustomOp::GetMayInplace = {};
    OrtCustomOp::ReleaseMayInplace = {};
    OrtCustomOp::GetAliasMap = {};
    OrtCustomOp::ReleaseAliasMap = {};
  }

  const std::string op_name_;
  const std::string execution_provider_;

  std::vector<ONNXTensorElementDataType> input_types_;
  std::vector<ONNXTensorElementDataType> output_types_;

  ShapeInferFn shape_infer_fn_ = {};

  int start_ver_ = 1;
  int end_ver_ = MAX_CUSTOM_OP_END_VER;

  void* compute_fn_ = {};
  void* compute_fn_return_status_ = {};
};
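
// A sketch of restricting an op to an opset range (not part of this header): the
// op and function names are hypothetical, and 14/17 are arbitrary versions.
//
//   std::unique_ptr<OrtLiteCustomOp> op_ptr{
//       Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter,
//                                       {} /* shape_infer_fn */, 14 /* start_ver */, 17 /* end_ver */)};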

//////////////////////////// OrtLiteCustomFunc ////////////////////////////////
// The struct implements function-as-op.
// E.g. a function might be defined as:
//   void Filter(const Ort::Custom::Tensor<float>& floats_in, Ort::Custom::Tensor<float>& floats_out) { ... }
// It could be registered this way:
//   Ort::CustomOpDomain v2_domain{"v2"};
//   std::unique_ptr<OrtLiteCustomOp> fil_op_ptr{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)};
//   v2_domain.Add(fil_op_ptr.get());
//   session_options.Add(v2_domain);
// For the complete example, please search for the keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
template <typename... Args>
struct OrtLiteCustomFunc : public OrtLiteCustomOp {
  using ComputeFn = void (*)(Args...);
  using ComputeFnReturnStatus = Status (*)(Args...);
  using MyType = OrtLiteCustomFunc<Args...>;

  struct Kernel {
    size_t num_input_{};
    size_t num_output_{};
    ComputeFn compute_fn_{};
    ComputeFnReturnStatus compute_fn_return_status_{};
    std::string ep_{};
  };

  OrtLiteCustomFunc(const char* op_name,
                    const char* execution_provider,
                    ComputeFn compute_fn,
                    ShapeInferFn shape_infer_fn = {},
                    int start_ver = 1,
                    int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
    compute_fn_ = reinterpret_cast<void*>(compute_fn);
    ParseArgs<Args...>(input_types_, output_types_);

    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      std::vector<ArgPtr> args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
    };

    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto kernel = std::make_unique<Kernel>();
      auto me = static_cast<const MyType*>(this_);
      kernel->compute_fn_ = reinterpret_cast<ComputeFn>(me->compute_fn_);
      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
      auto self = static_cast<const OrtLiteCustomFunc*>(this_);
      kernel->ep_ = self->execution_provider_;
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete reinterpret_cast<Kernel*>(op_kernel);
    };

    if (shape_infer_fn_) {
      OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
        auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
        ShapeInferContext ctx(&GetApi(), ort_ctx);
        return shape_info_fn(ctx);
      };
    }
  }

  OrtLiteCustomFunc(const char* op_name,
                    const char* execution_provider,
                    ComputeFnReturnStatus compute_fn_return_status,
                    ShapeInferFn shape_infer_fn = {},
                    int start_ver = 1,
                    int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
    compute_fn_return_status_ = reinterpret_cast<void*>(compute_fn_return_status);
    ParseArgs<Args...>(input_types_, output_types_);

    OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      std::vector<ArgPtr> args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      return std::apply([kernel](Args const&... t_args) {
        Status status = kernel->compute_fn_return_status_(t_args...);
        return status.release();
      }, t);
    };

    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto kernel = std::make_unique<Kernel>();
      auto me = static_cast<const MyType*>(this_);
      kernel->compute_fn_return_status_ = reinterpret_cast<ComputeFnReturnStatus>(me->compute_fn_return_status_);
      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
      auto self = static_cast<const OrtLiteCustomFunc*>(this_);
      kernel->ep_ = self->execution_provider_;
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete reinterpret_cast<Kernel*>(op_kernel);
    };

    if (shape_infer_fn_) {
      OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
        auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
        ShapeInferContext ctx(&GetApi(), ort_ctx);
        return shape_info_fn(ctx);
      };
    }
  }
};  // struct OrtLiteCustomFunc
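
// A sketch of a Status-returning compute function (not part of this header): the
// op name "SafeDiv" is hypothetical. Returning Ort::Status instead of void lets
// failures propagate through KernelComputeV2 without throwing.
//
//   Ort::Status SafeDiv(const Ort::Custom::Tensor<float>& a,
//                       const Ort::Custom::Tensor<float>& b,
//                       Ort::Custom::Tensor<float>& out) {
//     if (a.NumberOfElement() != b.NumberOfElement()) {
//       return Ort::Status("shape mismatch", OrtErrorCode::ORT_INVALID_ARGUMENT);
//     }
//     float* y = out.Allocate(a.Shape());
//     for (int64_t i = 0; i < a.NumberOfElement(); ++i) y[i] = a.Data()[i] / b.Data()[i];
//     return Ort::Status{nullptr};
//   }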

/////////////////////////// OrtLiteCustomStruct ///////////////////////////
// The struct implements struct-as-op.
// E.g. a struct might be defined as:
//   struct Merge {
//     Merge(const OrtApi* ort_api, const OrtKernelInfo* info) {...}
//     void Compute(const Ort::Custom::Tensor<std::string_view>& strings_in,
//                  std::string_view string_in,
//                  Ort::Custom::Tensor<std::string>* strings_out) {...}
//     bool reverse_ = false;
//   };
// It could be registered this way:
//   Ort::CustomOpDomain v2_domain{"v2"};
//   std::unique_ptr<OrtLiteCustomOp> mrg_op_ptr{Ort::Custom::CreateLiteCustomOp<Merge>("Merge", "CPUExecutionProvider")};
//   v2_domain.Add(mrg_op_ptr.get());
//   session_options.Add(v2_domain);
// For the complete example, please search for the keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
template <typename CustomOp>
struct OrtLiteCustomStruct : public OrtLiteCustomOp {
  template <typename... Args>
  using CustomComputeFn = void (CustomOp::*)(Args...);

  template <typename... Args>
  using CustomComputeFnReturnStatus = Status (CustomOp::*)(Args...);

  using MyType = OrtLiteCustomStruct<CustomOp>;

  struct Kernel {
    size_t num_input_{};
    size_t num_output_{};
    std::unique_ptr<CustomOp> custom_op_;
    std::string ep_{};
  };

  OrtLiteCustomStruct(const char* op_name,
                      const char* execution_provider,
                      int start_ver = 1,
                      int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) {
    SetCompute(&CustomOp::Compute);

    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto kernel = std::make_unique<Kernel>();
      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
      kernel->custom_op_ = std::make_unique<CustomOp>(ort_api, info);
      auto self = static_cast<const OrtLiteCustomStruct*>(this_);
      kernel->ep_ = self->execution_provider_;
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete reinterpret_cast<Kernel*>(op_kernel);
    };

    SetShapeInfer<CustomOp>(0);
  }

  template <typename... Args>
  void SetCompute(CustomComputeFn<Args...>) {
    ParseArgs<Args...>(input_types_, output_types_);
    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      ArgPtrs args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
    };
  }

  template <typename... Args>
  void SetCompute(CustomComputeFnReturnStatus<Args...>) {
    ParseArgs<Args...>(input_types_, output_types_);
    OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      ArgPtrs args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      return std::apply([kernel](Args const&... t_args) {
        Status status = kernel->custom_op_->Compute(t_args...);
        return status.release();
      }, t);
    };
  }

  template <typename C>
  decltype(&C::InferOutputShape) SetShapeInfer(decltype(&C::InferOutputShape)) {
    OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
      ShapeInferContext ctx(&GetApi(), ort_ctx);
      return C::InferOutputShape(ctx);
    };
    return {};
  }

  template <typename C>
  void SetShapeInfer(...) {
    OrtCustomOp::InferOutputShapeFn = {};
  }
};  // struct OrtLiteCustomStruct
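
// A sketch of opting into shape inference from a struct (not part of this
// header): if the custom struct defines a static InferOutputShape, SetShapeInfer
// picks it up via SFINAE; otherwise InferOutputShapeFn stays empty. The body
// shown below is hypothetical.
//
//   struct Merge {
//     // ... constructor and Compute as above ...
//     static Ort::Status InferOutputShape(Ort::ShapeInferContext& ctx) {
//       return ctx.SetOutputShape(0, ctx.GetInputShape(0));  // output 0 mirrors input 0
//     }
//   };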

/////////////////////////// CreateLiteCustomOp ////////////////////////////
template <typename... Args>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider,
                                    void (*custom_compute_fn)(Args...),
                                    Status (*shape_infer_fn)(ShapeInferContext&) = {},
                                    int start_ver = 1,
                                    int end_ver = MAX_CUSTOM_OP_END_VER) {
  using LiteOp = OrtLiteCustomFunc<Args...>;
  return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release();
}

template <typename... Args>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider,
                                    Status (*custom_compute_fn_v2)(Args...),
                                    Status (*shape_infer_fn)(ShapeInferContext&) = {},
                                    int start_ver = 1,
                                    int end_ver = MAX_CUSTOM_OP_END_VER) {
  using LiteOp = OrtLiteCustomFunc<Args...>;
  return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release();
}

template <typename CustomOp>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider,
                                    int start_ver = 1,
                                    int end_ver = MAX_CUSTOM_OP_END_VER) {
  using LiteOp = OrtLiteCustomStruct<CustomOp>;
  return std::make_unique<LiteOp>(op_name, execution_provider, start_ver, end_ver).release();
}

}  // namespace Custom
}  // namespace Ort