Browse Source

fix on mac build

master
Yuya Hanai 3 years ago
parent
commit
7d73bf78c2
  1. 5
      example-onnx_mnist/src/ofApp.cpp
  2. 15
      src/ofxOnnxRuntime.cpp
  3. 7
      src/ofxOnnxRuntime.h

5
example-onnx_mnist/src/ofApp.cpp

@ -47,9 +47,12 @@ public:
pix.setFromExternalPixels(mnist2.getInputTensorData(), 28, 28, 1); pix.setFromExternalPixels(mnist2.getInputTensorData(), 28, 28, 1);
//mnist->Run(); //mnist->Run();
mnist2.run(); auto& result = mnist2.run();
const float *output_ptr = result.GetTensorMutableData<float>();
mnist_result.resize(10); mnist_result.resize(10);
cerr << "API : " << Ort::Global<void>::api_ << endl;
} }
void update() { void update() {

15
src/ofxOnnxRuntime.cpp

@ -66,13 +66,16 @@ namespace ofxOnnxRuntime
// 4. output names & output dimms // 4. output names & output dimms
num_outputs = ort_session->GetOutputCount(); num_outputs = ort_session->GetOutputCount();
output_node_names.resize(num_outputs); output_node_names.resize(num_outputs);
output_node_dims.clear();
output_values.clear();
for (unsigned int i = 0; i < num_outputs; ++i) for (unsigned int i = 0; i < num_outputs; ++i)
{ {
output_node_names[i] = ort_session->GetOutputName(i, allocator); output_node_names[i] = ort_session->GetOutputName(i, allocator);
Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo(i); Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo(i);
auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo(); auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();
auto output_dims = output_tensor_info.GetShape(); auto output_dims = output_tensor_info.GetShape();
output_node_dims.push_back(output_dims); output_node_dims.emplace_back(output_dims);
output_values.emplace_back(nullptr);
} }
} }
@ -81,14 +84,14 @@ namespace ofxOnnxRuntime
auto input_tensor_ = Ort::Value::CreateTensor<float>( auto input_tensor_ = Ort::Value::CreateTensor<float>(
memory_info_handler, input_values_handler.data(), input_tensor_size, memory_info_handler, input_values_handler.data(), input_tensor_size,
input_node_dims.data(), input_node_dims.size()); input_node_dims.data(), input_node_dims.size());
auto result = ort_session->Run(Ort::RunOptions{ nullptr }, input_node_names.data(), &input_tensor_, input_node_names.size(), ort_session->Run(Ort::RunOptions{ nullptr }, input_node_names.data(), &input_tensor_, input_node_names.size(),
output_node_names.data(), output_node_names.size()); output_node_names.data(), output_values.data(), output_node_names.size());
if (result.size() == 1) { if (output_values.size() == 1) {
return result.front(); return output_values.at(0);
} }
else { else {
return dummy_tensor_; return dummy_tensor;
} }
} }
} }

7
src/ofxOnnxRuntime.h

@ -20,6 +20,8 @@ namespace ofxOnnxRuntime
class BaseHandler class BaseHandler
{ {
public: public:
BaseHandler() {}
void setup(const std::string& onnx_path, const BaseSetting& base_setting = BaseSetting{ INFER_CPU, 0 }); void setup(const std::string& onnx_path, const BaseSetting& base_setting = BaseSetting{ INFER_CPU, 0 });
void setup2(const std::string& onnx_path, const Ort::SessionOptions& session_options); void setup2(const std::string& onnx_path, const Ort::SessionOptions& session_options);
@ -35,10 +37,11 @@ namespace ofxOnnxRuntime
std::vector<int64_t> input_node_dims; // 1 input only. std::vector<int64_t> input_node_dims; // 1 input only.
std::size_t input_tensor_size = 1; std::size_t input_tensor_size = 1;
std::vector<float> input_values_handler; std::vector<float> input_values_handler;
Ort::Value dummy_tensor_{ nullptr }; Ort::MemoryInfo memory_info_handler = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
Ort::MemoryInfo memory_info_handler = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
std::vector<const char *> output_node_names; std::vector<const char *> output_node_names;
std::vector<std::vector<int64_t>> output_node_dims; // >=1 outputs std::vector<std::vector<int64_t>> output_node_dims; // >=1 outputs
std::vector<Ort::Value> output_values;
Ort::Value dummy_tensor{nullptr};
int num_outputs = 1; int num_outputs = 1;
}; };
} }
Loading…
Cancel
Save