Browse Source

code formatted

master
Yuya Hanai 3 years ago
parent
commit
a88c07922d
  1. 6
      src/ofxOnnxRuntime.cpp
  2. 6
      src/ofxOnnxRuntime.h

6
src/ofxOnnxRuntime.cpp

@ -66,8 +66,8 @@ namespace ofxOnnxRuntime
// 4. output names & output dimms // 4. output names & output dimms
num_outputs = ort_session->GetOutputCount(); num_outputs = ort_session->GetOutputCount();
output_node_names.resize(num_outputs); output_node_names.resize(num_outputs);
output_node_dims.clear(); output_node_dims.clear();
output_values.clear(); output_values.clear();
for (unsigned int i = 0; i < num_outputs; ++i) for (unsigned int i = 0; i < num_outputs; ++i)
{ {
output_node_names[i] = ort_session->GetOutputName(i, allocator); output_node_names[i] = ort_session->GetOutputName(i, allocator);
@ -75,7 +75,7 @@ namespace ofxOnnxRuntime
auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo(); auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();
auto output_dims = output_tensor_info.GetShape(); auto output_dims = output_tensor_info.GetShape();
output_node_dims.emplace_back(output_dims); output_node_dims.emplace_back(output_dims);
output_values.emplace_back(nullptr); output_values.emplace_back(nullptr);
} }
} }

6
src/ofxOnnxRuntime.h

@ -20,7 +20,7 @@ namespace ofxOnnxRuntime
class BaseHandler class BaseHandler
{ {
public: public:
BaseHandler() {} BaseHandler() {}
void setup(const std::string& onnx_path, const BaseSetting& base_setting = BaseSetting{ INFER_CPU, 0 }); void setup(const std::string& onnx_path, const BaseSetting& base_setting = BaseSetting{ INFER_CPU, 0 });
void setup2(const std::string& onnx_path, const Ort::SessionOptions& session_options); void setup2(const std::string& onnx_path, const Ort::SessionOptions& session_options);
@ -40,8 +40,8 @@ namespace ofxOnnxRuntime
Ort::MemoryInfo memory_info_handler = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); Ort::MemoryInfo memory_info_handler = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
std::vector<const char *> output_node_names; std::vector<const char *> output_node_names;
std::vector<std::vector<int64_t>> output_node_dims; // >=1 outputs std::vector<std::vector<int64_t>> output_node_dims; // >=1 outputs
std::vector<Ort::Value> output_values; std::vector<Ort::Value> output_values;
Ort::Value dummy_tensor{nullptr}; Ort::Value dummy_tensor{ nullptr };
int num_outputs = 1; int num_outputs = 1;
}; };
} }

Loading…
Cancel
Save