Browse Source

input dims & product calc

mingw
cailean 2 days ago
parent
commit
4bad65e13b
  1. 22
      src/ofxOnnxRuntime.cpp
  2. 7
      src/ofxOnnxRuntime.h

22
src/ofxOnnxRuntime.cpp

@ -19,6 +19,8 @@ namespace ofxOnnxRuntime
// Store data types // Store data types
this->input_dtype = base_setting.input_dtype; this->input_dtype = base_setting.input_dtype;
this->output_dtype = base_setting.output_dtype; this->output_dtype = base_setting.output_dtype;
this->inputWidth = base_setting.width;
this->inputHeight = base_setting.height;
Ort::SessionOptions session_options; Ort::SessionOptions session_options;
session_options.SetIntraOpNumThreads(1); session_options.SetIntraOpNumThreads(1);
@ -214,8 +216,8 @@ namespace ofxOnnxRuntime
// Resize to 192x192 if needed // Resize to 192x192 if needed
cv::Mat resizedImage; cv::Mat resizedImage;
if (width != 192 || height != 192) { if (width != inputWidth || height != inputHeight) {
cv::resize(cvImage, resizedImage, cv::Size(192, 192)); cv::resize(cvImage, resizedImage, cv::Size(inputWidth, inputHeight));
} else { } else {
resizedImage = cvImage; resizedImage = cvImage;
} }
@ -224,9 +226,9 @@ namespace ofxOnnxRuntime
cv::Mat floatImage; cv::Mat floatImage;
resizedImage.convertTo(floatImage, CV_32F, 1.0/255.0); resizedImage.convertTo(floatImage, CV_32F, 1.0/255.0);
// Calculate offset in destination array // Calculate offset in destination array NEED TO CALC PRODUCT
size_t elementsPerImage = input_node_dims[1] * input_node_dims[2] * input_node_dims[3]; int elementsPerImage = CalculateProduct(input_node_dims);
size_t startPos = idx * elementsPerImage; int startPos = idx * elementsPerImage;
// Copy directly // Copy directly
float* floatPtr = reinterpret_cast<float*>(floatImage.data); float* floatPtr = reinterpret_cast<float*>(floatImage.data);
@ -244,8 +246,8 @@ namespace ofxOnnxRuntime
// Resize to 192x192 if needed // Resize to 192x192 if needed
cv::Mat resizedImage; cv::Mat resizedImage;
if (width != 192 || height != 192) { if (width != inputWidth || height != inputHeight) {
cv::resize(cvImage, resizedImage, cv::Size(192, 192)); cv::resize(cvImage, resizedImage, cv::Size(inputWidth, inputHeight));
} else { } else {
resizedImage = cvImage; resizedImage = cvImage;
} }
@ -254,9 +256,9 @@ namespace ofxOnnxRuntime
cv::Mat intImage; cv::Mat intImage;
resizedImage.convertTo(intImage, CV_32SC3); resizedImage.convertTo(intImage, CV_32SC3);
// Calculate offset in destination array // Calculate offset in destination array CALC PRODUCT
size_t elementsPerImage = 192 * 192 * 3; int elementsPerImage = CalculateProduct(input_node_dims);
size_t startPos = idx * elementsPerImage; int startPos = idx * elementsPerImage;
// Copy directly // Copy directly
int32_t* intPtr = reinterpret_cast<int32_t*>(intImage.data); int32_t* intPtr = reinterpret_cast<int32_t*>(intImage.data);

7
src/ofxOnnxRuntime.h

@ -24,6 +24,8 @@ namespace ofxOnnxRuntime
int device_id; int device_id;
ModelDataType input_dtype = FLOAT32; ModelDataType input_dtype = FLOAT32;
ModelDataType output_dtype = FLOAT32; ModelDataType output_dtype = FLOAT32;
int width;
int height;
}; };
class BaseHandler class BaseHandler
@ -31,7 +33,7 @@ namespace ofxOnnxRuntime
public: public:
BaseHandler() {} BaseHandler() {}
void setup(const std::string& onnx_path, const BaseSetting& base_setting = BaseSetting{ INFER_CPU, 0, FLOAT32, FLOAT32 }, const int& batch_size = 1, const bool debug = false, const bool timestamp = false); void setup(const std::string& onnx_path, const BaseSetting& base_setting = BaseSetting{ INFER_CPU, 0, FLOAT32, FLOAT32, 256, 256 }, const int& batch_size = 1, const bool debug = false, const bool timestamp = false);
void setup2(const std::string& onnx_path, const Ort::SessionOptions& session_options); void setup2(const std::string& onnx_path, const Ort::SessionOptions& session_options);
void setNames(); void setNames();
void setInputs(std::vector<ofImage*>& input_imgs); void setInputs(std::vector<ofImage*>& input_imgs);
@ -75,5 +77,8 @@ namespace ofxOnnxRuntime
ModelDataType input_dtype; ModelDataType input_dtype;
ModelDataType output_dtype; ModelDataType output_dtype;
int inputWidth;
int inputHeight;
}; };
} }

Loading…
Cancel
Save