diff --git a/example-onnx_mnist/src/ofApp.cpp b/example-onnx_mnist/src/ofApp.cpp
index 9194887..f4b19b4 100644
--- a/example-onnx_mnist/src/ofApp.cpp
+++ b/example-onnx_mnist/src/ofApp.cpp
@@ -47,9 +47,12 @@ public:
 		pix.setFromExternalPixels(mnist2.getInputTensorData(), 28, 28, 1);
 
 		//mnist->Run();
-		mnist2.run();
+		auto& result = mnist2.run();
+		const float *output_ptr = result.GetTensorMutableData<float>();
 
 		mnist_result.resize(10);
+
+		cerr << "API : " << Ort::Global<void>::api_ << endl;
 	}
 
 	void update() {
diff --git a/src/ofxOnnxRuntime.cpp b/src/ofxOnnxRuntime.cpp
index e882125..2ca6633 100644
--- a/src/ofxOnnxRuntime.cpp
+++ b/src/ofxOnnxRuntime.cpp
@@ -66,13 +66,16 @@ namespace ofxOnnxRuntime
 		// 4. output names & output dimms
 		num_outputs = ort_session->GetOutputCount();
 		output_node_names.resize(num_outputs);
+		output_node_dims.clear();
+		output_values.clear();
 		for (unsigned int i = 0; i < num_outputs; ++i)
 		{
 			output_node_names[i] = ort_session->GetOutputName(i, allocator);
 			Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo(i);
 			auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();
 			auto output_dims = output_tensor_info.GetShape();
-			output_node_dims.push_back(output_dims);
+			output_node_dims.emplace_back(output_dims);
+			output_values.emplace_back(nullptr);
 		}
 	}
 
@@ -81,14 +84,14 @@
 		auto input_tensor_ = Ort::Value::CreateTensor<float>(
 			memory_info_handler, input_values_handler.data(), input_tensor_size,
 			input_node_dims.data(), input_node_dims.size());
-		auto result = ort_session->Run(Ort::RunOptions{ nullptr }, input_node_names.data(), &input_tensor_, input_node_names.size(),
-			output_node_names.data(), output_node_names.size());
+		ort_session->Run(Ort::RunOptions{ nullptr }, input_node_names.data(), &input_tensor_, input_node_names.size(),
+			output_node_names.data(), output_values.data(), output_node_names.size());
 
-		if (result.size() == 1) {
-			return result.front();
+		if (output_values.size() == 1) {
+			return output_values.at(0);
 		}
 		else {
-			return dummy_tensor_;
+			return dummy_tensor;
 		}
 	}
-}
\ No newline at end of file
+}
diff --git a/src/ofxOnnxRuntime.h b/src/ofxOnnxRuntime.h
index a2945b5..e8001f5 100644
--- a/src/ofxOnnxRuntime.h
+++ b/src/ofxOnnxRuntime.h
@@ -20,6 +20,8 @@ namespace ofxOnnxRuntime
 	class BaseHandler
 	{
 	public:
+		BaseHandler() {}
+
 		void setup(const std::string& onnx_path, const BaseSetting& base_setting = BaseSetting{ INFER_CPU, 0 });
 		void setup2(const std::string& onnx_path, const Ort::SessionOptions& session_options);
 
@@ -35,10 +37,11 @@
 		std::vector<int64_t> input_node_dims; // 1 input only.
 		std::size_t input_tensor_size = 1;
 		std::vector<float> input_values_handler;
-		Ort::Value dummy_tensor_{ nullptr };
-		Ort::MemoryInfo memory_info_handler = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+		Ort::MemoryInfo memory_info_handler = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
 		std::vector<const char*> output_node_names;
 		std::vector<std::vector<int64_t>> output_node_dims; // >=1 outputs
+		std::vector<Ort::Value> output_values;
+		Ort::Value dummy_tensor{nullptr};
 		int num_outputs = 1;
 	};
-}
\ No newline at end of file
+}
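Note (not part of the patch): before this change, run() returned a reference into the local std::vector<Ort::Value> produced by Session::Run, which dangles as soon as run() returns; keeping the outputs in the output_values member keeps the returned Ort::Value& valid for the caller. A minimal sketch of reading that result on the caller side, assuming a single float output tensor (the helper name copyOutput is hypothetical, not part of the addon):

    #include "ofxOnnxRuntime.h"
    #include <vector>

    // Hypothetical helper, not part of the patch: copies the single output tensor
    // returned by BaseHandler::run() into a std::vector<float>.
    std::vector<float> copyOutput(ofxOnnxRuntime::BaseHandler& handler)
    {
        Ort::Value& result = handler.run();
        auto shape_info = result.GetTensorTypeAndShapeInfo();      // type/shape of the output tensor
        const float* data = result.GetTensorMutableData<float>();  // raw pointer to the scores
        return std::vector<float>(data, data + shape_info.GetElementCount());
    }

In the MNIST example, the same pattern could fill mnist_result with the 10 class scores instead of resizing it and copying by hand.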