From 5e4e3789107ac2cdf4af3975b7b1c126aa6e00e5 Mon Sep 17 00:00:00 2001
From: Yuya Hanai <hanasan.hanasan@gmail.com>
Date: Thu, 20 Jan 2022 21:33:48 +0900
Subject: [PATCH] add ofxOnnxRuntime::BaseHandler base class

Extract the ad-hoc MNIST struct from the example into a reusable
ofxOnnxRuntime::BaseHandler that wraps session creation (CPU, CUDA, or
TensorRT, selected via BaseSetting), input/output name and shape
discovery, and inference, and port example-onnx_mnist to the new API.
---
 README.md                        |  18 ++++-
 example-onnx_mnist/src/ofApp.cpp | 101 ++++++------------------------
 src/ofxOnnxRuntime.cpp           |  99 +++++++++++++++++++++++++++++
 src/ofxOnnxRuntime.h             |  44 ++++++++++++++
 4 files changed, 180 insertions(+), 82 deletions(-)
 create mode 100644 src/ofxOnnxRuntime.cpp

diff --git a/README.md b/README.md
index 0de2c17..429277f 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,6 @@
 
 !['test'](screenshot.png)
 
-
 ## Installation
 - macOS
     - copy `libonnxruntime.1.10.0.dylib` to `/usr/local/lib` 
@@ -27,3 +26,20 @@
 
 ## ToDo
 - check M1 Mac (should work), Linux CPU&GPU
+
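+## Usage
+A minimal sketch of the `BaseHandler` API (the model path and the input layout below are placeholders for your own model):
+
+```cpp
+ofxOnnxRuntime::BaseHandler model;
+model.setup("model.onnx"); // CPU by default; pass BaseSetting{ INFER_CUDA, 0 } for GPU
+
+// write your preprocessed input into the handler-owned tensor, then run
+float *input = model.getInputTensorData();
+// ... fill input here ...
+auto& output = model.run();
+const float *scores = output.GetTensorMutableData<float>();
+```
+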
+## Reference Implementation
+- I heavily referred to the [Lite.AI.ToolKit](https://github.com/DefTruth/lite.ai.toolkit) implementation.
diff --git a/example-onnx_mnist/src/ofApp.cpp b/example-onnx_mnist/src/ofApp.cpp
index 278f6e1..9194887 100644
--- a/example-onnx_mnist/src/ofApp.cpp
+++ b/example-onnx_mnist/src/ofApp.cpp
@@ -16,81 +16,10 @@ template <typename T> static void softmax(T &input) {
 	}
 }
 
-// This is the structure to interface with the MNIST model
-// After instantiation, set the input_image_ data to be the 28x28 pixel image of
-// the number to recognize Then call Run() to fill in the results_ data with the
-// probabilities of each result_ holds the index with highest probability (aka
-// the number the model thinks is in the image)
-struct MNIST {
-	MNIST() {
-
-#ifdef _MSC_VER
-		Ort::SessionOptions sf;
-
-#define USE_CUDA
-#define USE_TENSORRT
-
-#ifdef USE_CUDA
-#ifdef USE_TENSORRT
-		sf.AppendExecutionProvider_TensorRT(OrtTensorRTProviderOptions{ 0 });
-#endif
-		sf.AppendExecutionProvider_CUDA(OrtCUDAProviderOptions());
-#endif
-
-		string path = ofToDataPath("mnist-8.onnx", true);
-		std::wstring widestr = std::wstring(path.begin(), path.end());
-		session_ = make_shared<Ort::Session>(env, widestr.c_str(), sf);
-#else
-		// OSX
-		session_ = make_shared<Ort::Session>(
-			env, ofToDataPath("mnist-8.onnx", true).c_str(),
-			Ort::SessionOptions{ nullptr });
-#endif
-
-		auto memory_info =
-			Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
-		input_tensor_ = Ort::Value::CreateTensor<float>(
-			memory_info, input_image_.data(), input_image_.size(),
-			input_shape_.data(), input_shape_.size());
-		output_tensor_ = Ort::Value::CreateTensor<float>(
-			memory_info, results_.data(), results_.size(), output_shape_.data(),
-			output_shape_.size());
-	}
-
-	std::ptrdiff_t Run() {
-		const char *input_names[] = { "Input3" };
-		const char *output_names[] = { "Plus214_Output_0" };
-
-		session_->Run(Ort::RunOptions{ nullptr }, input_names, &input_tensor_, 1,
-			output_names, &output_tensor_, 1);
-		softmax(results_);
-		result_ = std::distance(results_.begin(),
-			std::max_element(results_.begin(), results_.end()));
-		return result_;
-	}
-
-	static constexpr const int width_ = 28;
-	static constexpr const int height_ = 28;
-
-	std::array<float, width_ * height_> input_image_{};
-	std::array<float, 10> results_{};
-	int64_t result_{ 0 };
-
-private:
-	Ort::Env env;
-	shared_ptr<Ort::Session>
-		session_; // {env, (const wchar_t*)ofToDataPath("mnist-8.onnx",
-				  // true).c_str(), Ort::SessionOptions{ nullptr }};
-
-	Ort::Value input_tensor_{ nullptr };
-	std::array<int64_t, 4> input_shape_{ 1, 1, width_, height_ };
-
-	Ort::Value output_tensor_{ nullptr };
-	std::array<int64_t, 2> output_shape_{ 1, 10 };
-};
-
 class ofApp : public ofBaseApp {
-	shared_ptr<MNIST> mnist;
+	ofxOnnxRuntime::BaseHandler mnist;
+	vector<float> mnist_result;
+
 	ofFbo fbo_render;
 	ofFbo fbo_classification;
 	ofFloatPixels pix;
@@ -102,8 +31,12 @@ public:
 		ofSetVerticalSync(true);
 		ofSetFrameRate(60);
 
-		mnist = make_shared<MNIST>();
-
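+		// prefer the TensorRT provider on the Windows/GPU build; default to CPU elsewhere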
+#ifdef _MSC_VER
+		mnist2.setup("mnist-8.onnx", ofxOnnxRuntime::BaseSetting{ ofxOnnxRuntime::INFER_TENSORRT });
+#else
+		mnist2.setup("mnist-8.onnx");
+#endif
 		fbo_render.allocate(280, 280, GL_RGB, 0);
 		fbo_render.getTexture().setTextureMinMagFilter(GL_NEAREST, GL_NEAREST);
 		fbo_render.begin();
@@ -111,9 +44,11 @@ public:
 		fbo_render.end();
 		fbo_classification.allocate(28, 28, GL_R32F, 0);
 
-		pix.setFromExternalPixels(&mnist->input_image_.front(), 28, 28, 1);
+		pix.setFromExternalPixels(mnist.getInputTensorData(), 28, 28, 1);
+
+		mnist.run();
 
-		mnist->Run();
+		mnist_result.resize(10);
 	}
 
 	void update() {
@@ -136,7 +71,11 @@ public:
 				fbo_classification.getHeight());
 			fbo_classification.end();
 			fbo_classification.readToPixels(pix);
-			mnist->Run();
+			auto& result = mnist.run();
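+			// copy the raw logits out of the output tensor, then softmax them into probabilities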
+			const float *output_ptr = result.GetTensorMutableData<float>();
+			memcpy(mnist_result.data(), output_ptr, mnist_result.size() * sizeof(float));
+			softmax(mnist_result);
 			prev_pt = pt;
 			prev_pressed = true;
 		}
@@ -152,14 +91,14 @@
 		fbo_classification.draw(0, 340);
 
 		// render result
 		for (int i = 0; i < 10; ++i) {
 			stringstream ss;
 			ss << i << ":" << std::fixed << std::setprecision(3)
-				<< mnist->results_[i];
+				<< mnist_result[i];
 			ofDrawBitmapString(ss.str(), 300, 70 + i * 30);
 			ofPushStyle();
 			ofSetColor(0, 255, 0);
-			ofDrawRectangle(360.0, 55 + i * 30, mnist->results_[i] * 300.0, 20);
+			ofDrawRectangle(360.0, 55 + i * 30, mnist_result[i] * 300.0, 20);
 			ofPopStyle();
 		}
 
diff --git a/src/ofxOnnxRuntime.cpp b/src/ofxOnnxRuntime.cpp
new file mode 100644
index 0000000..e882125
--- /dev/null
+++ b/src/ofxOnnxRuntime.cpp
@@ -0,0 +1,99 @@
+#include "ofxOnnxRuntime.h"
+#include "ofMain.h"
+
+namespace ofxOnnxRuntime
+{
+#ifdef _MSC_VER
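+	// Ort::Session expects a wide-char path on Windows, so convert using the current locale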
+	static std::wstring to_wstring(const std::string &str)
+	{
+		std::size_t len = str.size() + 1; // at most one wide char per byte, plus the null terminator
+		setlocale(LC_CTYPE, "");
+		wchar_t *p = new wchar_t[len];
+		mbstowcs(p, str.c_str(), len);
+		std::wstring wstr(p);
+		delete[] p;
+		return wstr;
+	}
+#endif
+
+	void BaseHandler::setup(const std::string & onnx_path, const BaseSetting & base_setting)
+	{
+		Ort::SessionOptions session_options;
+		if (base_setting.infer_type == INFER_TENSORRT) {
+			OrtTensorRTProviderOptions op;
+			memset(&op, 0, sizeof(op));
+			op.device_id = base_setting.device_id;
+			op.trt_fp16_enable = 1;
+			op.trt_engine_cache_enable = 1;
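+			// cache built TensorRT engines next to the model ("model.onnx" -> "model_trt_cache")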
+			std::string path = ofToDataPath(onnx_path, true);
+			ofStringReplace(path, ".onnx", "_trt_cache");
+			op.trt_engine_cache_path = path.c_str();
+			session_options.AppendExecutionProvider_TensorRT(op);
+		}
+		if (base_setting.infer_type == INFER_CUDA || base_setting.infer_type == INFER_TENSORRT) {
+			OrtCUDAProviderOptions op;
+			op.device_id = base_setting.device_id;
+			session_options.AppendExecutionProvider_CUDA(op);
+		}
+		this->setup2(onnx_path, session_options);
+	}
+
+	void BaseHandler::setup2(const std::string & onnx_path, const Ort::SessionOptions & session_options)
+	{
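+		// 1. create the session from the resolved model path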
+		std::string path = ofToDataPath(onnx_path, true);
+#ifdef _MSC_VER
+		ort_session = std::make_shared<Ort::Session>(ort_env, to_wstring(path).c_str(), session_options);
+#else
+		ort_session = std::make_shared<Ort::Session>(ort_env, path.c_str(), session_options);
+#endif
+
+		Ort::AllocatorWithDefaultOptions allocator;
+		
+		// 2. input name & input dims
+		auto* input_name = ort_session->GetInputName(0, allocator);
+		input_node_names.resize(1);
+		input_node_names[0] = input_name;
+		
+		// 3. type info.
+		Ort::TypeInfo type_info = ort_session->GetInputTypeInfo(0);
+		auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+		input_tensor_size = 1;
+		input_node_dims = tensor_info.GetShape();
+		for (auto& dim : input_node_dims) {
+			if (dim < 0) dim = 1; // dynamic dims are reported as -1; pin them to batch size 1
+			input_tensor_size *= dim;
+		}
+		input_values_handler.resize(input_tensor_size);
+
+		// 4. output names & output dims
+		num_outputs = ort_session->GetOutputCount();
+		output_node_names.resize(num_outputs);
+		for (unsigned int i = 0; i < num_outputs; ++i)
+		{
+			output_node_names[i] = ort_session->GetOutputName(i, allocator);
+			Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo(i);
+			auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();
+			auto output_dims = output_tensor_info.GetShape();
+			output_node_dims.push_back(output_dims);
+		}
+	}
+
+	Ort::Value& BaseHandler::run()
+	{
+		auto input_tensor = Ort::Value::CreateTensor<float>(
+			memory_info_handler, input_values_handler.data(), input_tensor_size,
+			input_node_dims.data(), input_node_dims.size());
+		// keep the outputs in a member: returning a reference into the local
+		// result vector would dangle as soon as this function returned
+		output_values = ort_session->Run(Ort::RunOptions{ nullptr }, input_node_names.data(), &input_tensor, input_node_names.size(),
+			output_node_names.data(), output_node_names.size());
+
+		if (!output_values.empty()) {
+			return output_values.front();
+		}
+		return dummy_tensor_;
+	}
+}
\ No newline at end of file
diff --git a/src/ofxOnnxRuntime.h b/src/ofxOnnxRuntime.h
index ff37e29..a2945b5 100644
--- a/src/ofxOnnxRuntime.h
+++ b/src/ofxOnnxRuntime.h
@@ -1,3 +1,47 @@
 #pragma once
 
 #include <onnxruntime_cxx_api.h>
+
+namespace ofxOnnxRuntime
+{
+	enum InferType
+	{
+		INFER_CPU = 0,
+		INFER_CUDA,
+		INFER_TENSORRT
+	};
+
+	struct BaseSetting
+	{
+		InferType infer_type;
+		int device_id;
+	};
+
+	class BaseHandler
+	{
+	public:
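+		// setup() builds Ort::SessionOptions from a BaseSetting; setup2() takes caller-provided options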
+		void setup(const std::string& onnx_path, const BaseSetting& base_setting = BaseSetting{ INFER_CPU, 0 });
+		void setup2(const std::string& onnx_path, const Ort::SessionOptions& session_options);
+
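+		// runs inference on the buffer from getInputTensorData(); the returned Ort::Value is owned by this handler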
+		Ort::Value& run();
+
+		float* getInputTensorData() {
+			return this->input_values_handler.data();
+		}
+	protected:
+		Ort::Env ort_env;
+		std::shared_ptr<Ort::Session> ort_session;
+		std::vector<const char *> input_node_names;
+		std::vector<int64_t> input_node_dims; // 1 input only.
+		std::size_t input_tensor_size = 1;
+		std::vector<float> input_values_handler;
+		Ort::Value dummy_tensor_{ nullptr };
+		std::vector<Ort::Value> output_values; // keeps run() outputs alive so the returned reference stays valid
+		Ort::MemoryInfo memory_info_handler = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+		std::vector<const char *> output_node_names;
+		std::vector<std::vector<int64_t>> output_node_dims; // >=1 outputs
+		int num_outputs = 1;
+	};
+}
\ No newline at end of file