diff --git a/src/ofxOnnxRuntime.cpp b/src/ofxOnnxRuntime.cpp index 607b6ec..a4451a2 100644 --- a/src/ofxOnnxRuntime.cpp +++ b/src/ofxOnnxRuntime.cpp @@ -19,6 +19,8 @@ namespace ofxOnnxRuntime // Store data types this->input_dtype = base_setting.input_dtype; this->output_dtype = base_setting.output_dtype; + this->inputWidth = base_setting.width; + this->inputHeight = base_setting.height; Ort::SessionOptions session_options; session_options.SetIntraOpNumThreads(1); @@ -214,8 +216,8 @@ namespace ofxOnnxRuntime // Resize to 192x192 if needed cv::Mat resizedImage; - if (width != 192 || height != 192) { - cv::resize(cvImage, resizedImage, cv::Size(192, 192)); + if (width != inputWidth || height != inputHeight) { + cv::resize(cvImage, resizedImage, cv::Size(inputWidth, inputHeight)); } else { resizedImage = cvImage; } @@ -224,9 +226,9 @@ namespace ofxOnnxRuntime cv::Mat floatImage; resizedImage.convertTo(floatImage, CV_32F, 1.0/255.0); - // Calculate offset in destination array - size_t elementsPerImage = input_node_dims[1] * input_node_dims[2] * input_node_dims[3]; - size_t startPos = idx * elementsPerImage; + // Calculate offset in destination array NEED TO CALC PRODUCT + int elementsPerImage = CalculateProduct(input_node_dims); + int startPos = idx * elementsPerImage; // Copy directly float* floatPtr = reinterpret_cast(floatImage.data); @@ -244,8 +246,8 @@ namespace ofxOnnxRuntime // Resize to 192x192 if needed cv::Mat resizedImage; - if (width != 192 || height != 192) { - cv::resize(cvImage, resizedImage, cv::Size(192, 192)); + if (width != inputWidth || height != inputHeight) { + cv::resize(cvImage, resizedImage, cv::Size(inputWidth, inputHeight)); } else { resizedImage = cvImage; } @@ -254,9 +256,9 @@ namespace ofxOnnxRuntime cv::Mat intImage; resizedImage.convertTo(intImage, CV_32SC3); - // Calculate offset in destination array - size_t elementsPerImage = 192 * 192 * 3; - size_t startPos = idx * elementsPerImage; + // Calculate offset in destination array CALC PRODUCT + int elementsPerImage = CalculateProduct(input_node_dims); + int startPos = idx * elementsPerImage; // Copy directly int32_t* intPtr = reinterpret_cast(intImage.data); diff --git a/src/ofxOnnxRuntime.h b/src/ofxOnnxRuntime.h index c62fc51..774a425 100644 --- a/src/ofxOnnxRuntime.h +++ b/src/ofxOnnxRuntime.h @@ -24,6 +24,8 @@ namespace ofxOnnxRuntime int device_id; ModelDataType input_dtype = FLOAT32; ModelDataType output_dtype = FLOAT32; + int width; + int height; }; class BaseHandler @@ -31,7 +33,7 @@ namespace ofxOnnxRuntime public: BaseHandler() {} - void setup(const std::string& onnx_path, const BaseSetting& base_setting = BaseSetting{ INFER_CPU, 0, FLOAT32, FLOAT32 }, const int& batch_size = 1, const bool debug = false, const bool timestamp = false); + void setup(const std::string& onnx_path, const BaseSetting& base_setting = BaseSetting{ INFER_CPU, 0, FLOAT32, FLOAT32, 256, 256 }, const int& batch_size = 1, const bool debug = false, const bool timestamp = false); void setup2(const std::string& onnx_path, const Ort::SessionOptions& session_options); void setNames(); void setInputs(std::vector& input_imgs); @@ -75,5 +77,8 @@ namespace ofxOnnxRuntime ModelDataType input_dtype; ModelDataType output_dtype; + + int inputWidth; + int inputHeight; }; }