Browse Source

code formatted

master
Yuya Hanai 3 years ago
parent
commit
a88c07922d
  1. 10
      src/ofxOnnxRuntime.cpp
  2. 8
      src/ofxOnnxRuntime.h

10
src/ofxOnnxRuntime.cpp

@ -48,12 +48,12 @@ namespace ofxOnnxRuntime
#endif #endif
Ort::AllocatorWithDefaultOptions allocator; Ort::AllocatorWithDefaultOptions allocator;
// 2. input name & input dims // 2. input name & input dims
auto* input_name = ort_session->GetInputName(0, allocator); auto* input_name = ort_session->GetInputName(0, allocator);
input_node_names.resize(1); input_node_names.resize(1);
input_node_names[0] = input_name; input_node_names[0] = input_name;
// 3. type info. // 3. type info.
Ort::TypeInfo type_info = ort_session->GetInputTypeInfo(0); Ort::TypeInfo type_info = ort_session->GetInputTypeInfo(0);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
@ -66,8 +66,8 @@ namespace ofxOnnxRuntime
// 4. output names & output dimms // 4. output names & output dimms
num_outputs = ort_session->GetOutputCount(); num_outputs = ort_session->GetOutputCount();
output_node_names.resize(num_outputs); output_node_names.resize(num_outputs);
output_node_dims.clear(); output_node_dims.clear();
output_values.clear(); output_values.clear();
for (unsigned int i = 0; i < num_outputs; ++i) for (unsigned int i = 0; i < num_outputs; ++i)
{ {
output_node_names[i] = ort_session->GetOutputName(i, allocator); output_node_names[i] = ort_session->GetOutputName(i, allocator);
@ -75,7 +75,7 @@ namespace ofxOnnxRuntime
auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo(); auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();
auto output_dims = output_tensor_info.GetShape(); auto output_dims = output_tensor_info.GetShape();
output_node_dims.emplace_back(output_dims); output_node_dims.emplace_back(output_dims);
output_values.emplace_back(nullptr); output_values.emplace_back(nullptr);
} }
} }

8
src/ofxOnnxRuntime.h

@ -20,8 +20,8 @@ namespace ofxOnnxRuntime
class BaseHandler class BaseHandler
{ {
public: public:
BaseHandler() {} BaseHandler() {}
void setup(const std::string& onnx_path, const BaseSetting& base_setting = BaseSetting{ INFER_CPU, 0 }); void setup(const std::string& onnx_path, const BaseSetting& base_setting = BaseSetting{ INFER_CPU, 0 });
void setup2(const std::string& onnx_path, const Ort::SessionOptions& session_options); void setup2(const std::string& onnx_path, const Ort::SessionOptions& session_options);
@ -40,8 +40,8 @@ namespace ofxOnnxRuntime
Ort::MemoryInfo memory_info_handler = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); Ort::MemoryInfo memory_info_handler = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
std::vector<const char *> output_node_names; std::vector<const char *> output_node_names;
std::vector<std::vector<int64_t>> output_node_dims; // >=1 outputs std::vector<std::vector<int64_t>> output_node_dims; // >=1 outputs
std::vector<Ort::Value> output_values; std::vector<Ort::Value> output_values;
Ort::Value dummy_tensor{nullptr}; Ort::Value dummy_tensor{ nullptr };
int num_outputs = 1; int num_outputs = 1;
}; };
} }

Loading…
Cancel
Save