diff --git a/.gitignore b/.gitignore index 5f8f815..e634954 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ /bin/data/images /bin/data/recordings /bin/data/models +/bin/data/dataset /bin/libs/ ######### diff --git a/bin/data/debug_binaryMat.png b/bin/data/debug_binaryMat.png index 40b298c..72724be 100644 Binary files a/bin/data/debug_binaryMat.png and b/bin/data/debug_binaryMat.png differ diff --git a/bin/data/new_tree.bin b/bin/data/new_tree.bin index e00ee95..6336ded 100644 Binary files a/bin/data/new_tree.bin and b/bin/data/new_tree.bin differ diff --git a/bin/data/shaders/espDepth.frag b/bin/data/shaders/espDepth.frag new file mode 100644 index 0000000..8d4d229 --- /dev/null +++ b/bin/data/shaders/espDepth.frag @@ -0,0 +1,38 @@ +#version 150 +precision highp float; + +uniform sampler2DRect tex0; // Your FBO texture (RGB FBO) +uniform float texW; // FBO texture width +uniform float texH; // FBO texture height + +out vec4 fragColor; // Output color +in vec2 varyingtexcoord; + +void main() { + float threshold = 0.7; + float depth = texture(tex0, varyingtexcoord).r; + + // Determine which quadrant we're in + int quadrantX = int(varyingtexcoord.x / (256.0 / 2.0)); + int quadrantY = int(varyingtexcoord.y / (336.0 / 2.0)); + int quadrant = quadrantY * 2 + quadrantX; + + vec3 baseColor; + vec3 mixColor = vec3(0.0); // White for mixing in all quadrants + + // Choose base color based on quadrant + if (quadrant == 0) { + baseColor = vec3(1.0, 0.0, 0.0); // Red for bottom-left + } else if (quadrant == 1) { + baseColor = vec3(0.0, 1.0, 0.0); // Green for bottom-right + } else if (quadrant == 2) { + baseColor = vec3(0.0, 0.0, 1.0); // Blue for top-left + } else { + baseColor = vec3(1.0, 1.0, 1.0); // White for top-right + } + + // Mix the base color with white based on the depth + vec3 color = mix(mixColor, baseColor, depth); + + fragColor = vec4(color, 1.0); +} \ No newline at end of file diff --git a/bin/data/shaders/espDepth.vert b/bin/data/shaders/espDepth.vert new file mode 100644 index 0000000..bf5c2b0 --- /dev/null +++ b/bin/data/shaders/espDepth.vert @@ -0,0 +1,13 @@ +#version 150 + +uniform mat4 modelViewProjectionMatrix; + +in vec4 position; +in vec2 texcoord; + +out vec2 varyingtexcoord; + +void main(){ + varyingtexcoord = texcoord; + gl_Position = modelViewProjectionMatrix * position; +} \ No newline at end of file diff --git a/bin/image-to-mesh b/bin/image-to-mesh index f8cfcf6..3dbe4b4 100755 Binary files a/bin/image-to-mesh and b/bin/image-to-mesh differ diff --git a/src/Bullet.cpp b/src/Bullet.cpp index 1f69f46..640fc53 100644 --- a/src/Bullet.cpp +++ b/src/Bullet.cpp @@ -26,7 +26,7 @@ void Bullet::setup(vector& _nodes){ shader.load("shaders/vertex_snap"); /* assign nodes */ - nodes = _nodes; + //nodes = _nodes; std::cout << workerThreads.size() << std::endl; } @@ -128,19 +128,18 @@ void Bullet::draw(){ now = false; } -void Bullet::addMesh(ofMesh _mesh, ofMesh _simple_mesh, ofTexture _tex){ +void Bullet::addMesh(ofMesh _mesh, ofMesh _simple_mesh, Node& _node){ std::lock_guard lock(shapeMutex); - Node n; - n.tex = _tex; + // n.tex = _tex; float rand = ofRandom(0.01, 0.02); glm::vec3 random_szie(rand, rand, rand); - n.scale = random_szie; + _node.scale = random_szie; ofQuaternion startRot = ofQuaternion(0., 0., 0., PI); - ofVec3f target_location = ofVec3f( ofRandom(0, 0), ofRandom(0, 0), -5 ); - ofVec3f start_location = ofVec3f( ofRandom(-300, 300), ofRandom(-300, 300), -5 ); + // ofVec3f target_location = ofVec3f( ofRandom(0, 0), ofRandom(0, 0), -5 ); + ofVec3f start_location = ofVec3f( ofRandom(-3000, 3000), ofRandom(-3000, 3000) -5 ); ofxBulletCustomShape* s = new ofxBulletCustomShape(); - s->addMesh(_simple_mesh, n.scale * 1.4, true); + s->addMesh(_simple_mesh, _node.scale * 1.4, true); s->create( world.world, start_location, startRot, 3.); s->add(); s->getRigidBody()->setAngularFactor(btVector3(0, 0, 1)); @@ -150,15 +149,15 @@ void Bullet::addMesh(ofMesh _mesh, ofMesh _simple_mesh, ofTexture _tex){ s->getRigidBody()->setRestitution(btScalar(0.5)); s->getRigidBody()->setFriction(btScalar(1.0)); - positions.push_back(target_location); + //positions.push_back(target_location); // Set how the col mesh is drawn! _simple_mesh.setMode(OF_PRIMITIVE_LINES); - n.collider = s; - n.mesh = _mesh; - n.col_mesh = _simple_mesh; - nodes.push_back(n); + _node.collider = s; + _node.mesh = _mesh; + _node.col_mesh = _simple_mesh; + // nodes.push_back(n); } void Bullet::workerThreadFunction(int threadId) { @@ -209,24 +208,25 @@ void Bullet::updateShapeBatch(size_t start, size_t end) { std::lock_guard lock(shapeMutex); glm::vec3 pos = nodes[i].collider->getPosition(); - glm::vec3 direction = glm::vec3(0, 0, -5) - pos; + glm::vec3 target_pos(nodes[i].tsne_position.x, nodes[i].tsne_position.y, -5); + glm::vec3 direction = target_pos - pos; float dist_sq = glm::length(direction); glm::vec3 norm_dir = glm::normalize(direction); nodes[i].collider->applyCentralForce(1.0f * norm_dir); // Apply repulsion force if needed - if (shouldApplyRepulsion && i != repulsionNodeIndex) { - const Node& repulsionNode = nodes[repulsionNodeIndex]; - glm::vec3 repulsionDir = pos - repulsionNode.collider->getPosition(); - float repulsionDist = glm::length(repulsionDir); - - if (repulsionDist < repulsionRadius && repulsionDist > 0) { - repulsionDir = glm::normalize(repulsionDir); - float forceMagnitude = repulsionStrength * (1.0f - repulsionDist / repulsionRadius); - glm::vec3 repulsionForce = repulsionDir * forceMagnitude; - nodes[i].collider->applyCentralForce(repulsionForce * 100.0f); - } - } + // if (shouldApplyRepulsion && i != repulsionNodeIndex) { + // const Node& repulsionNode = nodes[repulsionNodeIndex]; + // glm::vec3 repulsionDir = pos - repulsionNode.collider->getPosition(); + // float repulsionDist = glm::length(repulsionDir); + + // if (repulsionDist < repulsionRadius && repulsionDist > 0) { + // repulsionDir = glm::normalize(repulsionDir); + // float forceMagnitude = repulsionStrength * (1.0f - repulsionDist / repulsionRadius); + // glm::vec3 repulsionForce = repulsionDir * forceMagnitude; + // nodes[i].collider->applyCentralForce(repulsionForce * 100.0f); + // } + // } } } @@ -267,4 +267,12 @@ bool Bullet::checkNodeVisibility(const Node& n){ } else { return false; } +} + +glm::vec3 Bullet::getCameraPosition(){ + return camera.getPosition(); +} + +void Bullet::setNodes(vector& _nodes){ + nodes = _nodes; } \ No newline at end of file diff --git a/src/Bullet.h b/src/Bullet.h index 22b335d..9b608b9 100644 --- a/src/Bullet.h +++ b/src/Bullet.h @@ -22,12 +22,14 @@ class Bullet{ void setup(vector& _nodes); void update(); void draw(); - void addMesh(ofMesh _mesh, ofMesh _simple_mesh, ofTexture _tex); + void addMesh(ofMesh _mesh, ofMesh _simple_mesh, Node& _node); float easeInOutCubic(float t); float calculateGridSize(float zoom); void setNewCameraEndpoint(); glm::vec4 getCameraBounds(const ofCamera& cam); bool checkNodeVisibility(const Node& n); + glm::vec3 getCameraPosition(); + void setNodes(vector& _nodes); ofxBulletWorldRigid world; vector bounds; ofxBulletCustomShape* boundsShape; diff --git a/src/ofApp.cpp b/src/ofApp.cpp index 1190429..b45c7c3 100644 --- a/src/ofApp.cpp +++ b/src/ofApp.cpp @@ -11,6 +11,7 @@ void ofApp::setup(){ model_esp_out_fbo.allocate(128 * 2, 168 * 2, GL_RGB); /* k-d image comp (4-images) */ + kd_out.allocate(128 * 2, 168 * 2, GL_RGB); esp_comp_fbo.allocate(128 * 2, 168 * 2, GL_RGB); /* input images for model */ @@ -34,11 +35,12 @@ void ofApp::setup(){ /* load */ shaders.load("shaders/dither"); + esp_shader.load("shaders/espDepth"); ORTCHAR_T* modelPath = "/home/cailean/Desktop/openframeworks/of_v0.12.0_linux64gcc6_release/apps/myApps/image-to-mesh/bin/data/models/depth_anything_v2_vits.onnx"; /* setup */ - //bullet.setup(nodes); + bullet.setup(nodes); depth_onnx.Setup(modelPath, true, true); depth_onnx_esp.Setup(modelPath, true, true); @@ -46,52 +48,31 @@ void ofApp::setup(){ depth_thread.setup(&model_image, &model_outptut_fbo, &depth_onnx); depth_esp.setup(&model_image_esp, &model_esp_out_fbo, &depth_onnx_esp); - - /* mesh generation test */ - std::string path = "images"; - ofDirectory dir(path); - dir.allowExt("png"); - dir.listDir(); - ofLoadImage(bayer, "images/bayer.png"); bayer.setTextureWrap(GL_REPEAT, GL_REPEAT); - // Loop through files and load images - for (size_t i = 0; i < dir.size(); i++) { - std::string filePath = dir.getPath(i); - ofImage img; - if (img.load(filePath)) { // Load the image - images.push_back(img); // Add the image to the list - } else { - ofLog() << "Failed to load: " << filePath; // Log if failed to load - } - } - - - - // vector mesh_list; + createNodes("data/json/embeddings.json"); - // for(auto& img : images){ - // mesh_list = mesh_generator.generateSimplifiedMesh(img); - // bullet.addMesh(mesh_list[0], mesh_list[1], img.getTexture()); - // mesh_list.clear(); - // } + buildMeshes(); - last_updated_time = ofGetElapsedTimef(); + buildKDTree(); - createNodes("data/json/embeddings.json"); + bullet.setNodes(nodes); server = std::make_unique(6762, embed, vp_tree, nodes, false, "192.168.0.253", 2000, "search"); server->start(); + + last_updated_time = ofGetElapsedTimef(); } //-------------------------------------------------------------- void ofApp::update(){ - server->update(model_esp_out_fbo); + server->update(esp_comp_fbo); float current_time = ofGetElapsedTimef(); if(current_time - last_updated_time >= 3){ + buildKDTree(); getNearestImages(); last_updated_time = current_time; } @@ -120,7 +101,7 @@ void ofApp::update(){ std::cout << "Model did not run" << std::endl; } - //bullet.update(); + bullet.update(); } //-------------------------------------------------------------- @@ -129,7 +110,7 @@ void ofApp::draw(){ ofPushStyle(); map_fbo.begin(); ofClear(ofColor::grey); - //bullet.draw(); + bullet.draw(); map_fbo.end(); ofPopStyle(); @@ -152,7 +133,17 @@ void ofApp::draw(){ comp_fbo.draw(0,0); model_outptut_fbo.draw(ofGetWindowWidth() / 2, 0); + + + esp_comp_fbo.begin(); + + esp_shader.begin(); + esp_shader.setUniformTexture("tex0", model_esp_out_fbo.getTexture(), 0); model_esp_out_fbo.draw(0, 0); + esp_comp_fbo.end(); + esp_shader.end(); + + esp_comp_fbo.draw(0, 0); //server->print(); //shader_fbo.draw(0, 0); @@ -182,25 +173,28 @@ void ofApp::draw(){ /* creates an fbo with four cropped images, to preprare for model input */ void ofApp::getNearestImages(){ - esp_comp_fbo.begin(); + glm::vec3 cam_position = bullet.getCameraPosition(); + queryKD(cam_position, 4); + + kd_out.begin(); ofClear(255, 255, 255, 0); - ofImage random_image; + ofImage selected_image; int imageIndex = 0; for(auto& img : esp_images){ - random_image = images[ofRandom(images.size() - 1)]; + selected_image = nodes[kd_result[imageIndex].second].img; // Calculate the scaling factor - float widthRatio = 128.0f / random_image.getWidth(); - float heightRatio = 168.0f / random_image.getHeight(); + float widthRatio = 128.0f / selected_image.getWidth(); + float heightRatio = 168.0f / selected_image.getHeight(); float scale = std::max(widthRatio, heightRatio); // Calculate new dimensions - int newWidth = std::ceil(random_image.getWidth() * scale); - int newHeight = std::ceil(random_image.getHeight() * scale); + int newWidth = std::ceil(selected_image.getWidth() * scale); + int newHeight = std::ceil(selected_image.getHeight() * scale); // Resize the image - random_image.resize(newWidth, newHeight); + selected_image.resize(newWidth, newHeight); // Calculate the crop position to center the image int cropX = (newWidth - 128) / 2; @@ -215,15 +209,15 @@ void ofApp::getNearestImages(){ int drawY = (imageIndex / 2) * 168; // Draw the resized and cropped image - random_image.drawSubsection(drawX, drawY, cropWidth, cropHeight, cropX, cropY); + selected_image.drawSubsection(drawX, drawY, cropWidth, cropHeight, cropX, cropY); imageIndex++; if (imageIndex >= 4) break; // Stop after drawing 4 images } - esp_comp_fbo.end(); + kd_out.end(); - esp_comp_fbo.readToPixels(esp_comp_pixels); + kd_out.readToPixels(esp_comp_pixels); model_image_esp.setFromPixels(esp_comp_pixels); } @@ -244,6 +238,7 @@ void ofApp::createNodes(std::string json_path){ if(j.contains("vector") && j["vector"].is_array()){ Node n; n.img.load(j["image"]); + n.tex = n.img.getTexture(); std::vector t_embedding; for (const auto& value: j["vector"]){ @@ -278,9 +273,8 @@ void ofApp::createNodes(std::string json_path){ for(size_t i = 0; i < tsne_points.size(); i++){ const auto& vec = tsne_points[i]; auto& n = nodes[i]; - n.tsne_position = (glm::vec3(vec[0] * tsne_scale, vec[1] * tsne_scale, -5.0f)); + n.tsne_position = (glm::vec3(((vec[0] * 2) - 1) * tsne_scale, ((vec[1] * 2) - 1) * tsne_scale, -5.0f)); } - /* vp-test */ // auto queries = server->generateRandomVectors(10, 7); @@ -330,3 +324,55 @@ void ofApp::exit(){ server->close(); } +void ofApp::buildMeshes(){ + vector mesh_list; + ofLogNotice() << "building meshes: " << nodes.size(); + for(auto& n : nodes){ + mesh_list = mesh_generator.generateSimplifiedMesh(n.img); + bullet.addMesh(mesh_list[0], mesh_list[1], n); + mesh_list.clear(); + } + ofLogNotice() << "finished building meshes"; +} + +void ofApp::buildKDTree(){ + pointVec kd_points; + + for(const auto& n : nodes){ + + glm::vec3 n_pos = n.collider->getPosition(); + + std::vector< double > p = { + static_cast(n_pos.x), + static_cast(n_pos.y) + }; + + kd_points.push_back(p); + } + + kd_tree = std::make_unique(kd_points); + + kd_result.resize(4); +} + +void ofApp::queryKD(glm::vec3& _position, int k){ + + kd_result.clear(); + + vector kd_input = { + static_cast(_position.x), + static_cast(_position.y) + }; + + auto res = kd_tree->nearest_pointIndices(kd_input, k); + + for (const auto& r : res) { + kd_result.push_back(r); + } + + size_t index = kd_result[0].second; + + std::cout << "camera positions: " << _position << ", nearest node: "<< nodes[index].tsne_position << " " << index << std::endl; + +} + diff --git a/src/ofApp.h b/src/ofApp.h index fd01eae..5b7ee5b 100644 --- a/src/ofApp.h +++ b/src/ofApp.h @@ -11,6 +11,7 @@ #include "network/Request.h" #include "network/Server.h" #include "ofxTSNE.h" +#include "utils/KDTree.hpp" class ofApp : public ofBaseApp{ @@ -21,6 +22,9 @@ class ofApp : public ofBaseApp{ void getNearestImages(); void createNodes(std::string json_path); std::vector> createDoubleVectorFromNodes(const std::vector& nodes); + void buildKDTree(); + void queryKD(glm::vec3& _position, int k); + void buildMeshes(); void exit(); ofEasyCam cam; @@ -37,9 +41,11 @@ class ofApp : public ofBaseApp{ ofFbo model_esp_out_fbo; ofFbo esp_comp_fbo; + ofFbo kd_out; ofImage model_image_esp; ofShader shaders; + ofShader esp_shader; ofPlanePrimitive plane; ofTexture bayer; @@ -77,5 +83,8 @@ class ofApp : public ofBaseApp{ bool runManually = false; float tsne_scale = 1000; + /* kd tree */ + std::unique_ptr kd_tree; + std::vector kd_result; }; diff --git a/src/utils/KDTree.cpp b/src/utils/KDTree.cpp new file mode 100644 index 0000000..c0017f4 --- /dev/null +++ b/src/utils/KDTree.cpp @@ -0,0 +1,292 @@ +/// @file KDTree.cpp +/// @author J. Frederico Carvalho +/// +/// This is an adaptation of the KD-tree implementation in rosetta code +/// https://rosettacode.org/wiki/K-d_tree +/// +/// It is a reimplementation of the C code using C++. It also includes a few +/// more queries than the original, namely finding all points at a distance +/// smaller than some given distance to a point. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "KDTree.hpp" + +KDNode::KDNode() = default; + +KDNode::KDNode(point_t const& pt, size_t const& idx_, KDNodePtr const& left_, + KDNodePtr const& right_) { + x = pt; + index = idx_; + left = left_; + right = right_; +} + +KDNode::KDNode(pointIndex const& pi, KDNodePtr const& left_, + KDNodePtr const& right_) { + x = pi.first; + index = pi.second; + left = left_; + right = right_; +} + +KDNode::~KDNode() = default; + +double KDNode::coord(size_t const& idx) { return x.at(idx); } +KDNode::operator bool() { return (!x.empty()); } +KDNode::operator point_t() { return x; } +KDNode::operator size_t() { return index; } +KDNode::operator pointIndex() { return std::make_pair(x, index); } + +KDNodePtr NewKDNodePtr() { + KDNodePtr mynode = std::make_shared(); + return mynode; +} + +inline double dist2(point_t const& a, point_t const& b) { + assert(a.size() == b.size()); + double distc = 0; + for (size_t i = 0; i < a.size(); i++) { + double di = a.at(i) - b.at(i); + distc += di * di; + } + return distc; +} + +inline double dist2(KDNodePtr const& a, KDNodePtr const& b) { + return dist2(a->x, b->x); +} + +comparer::comparer(size_t idx_) : idx{idx_} {} + +inline bool comparer::compare_idx(pointIndex const& a, pointIndex const& b) { + return (a.first.at(idx) < b.first.at(idx)); +} + +inline void sort_on_idx(pointIndexArr::iterator const& begin, + pointIndexArr::iterator const& end, size_t idx) { + comparer comp(idx); + comp.idx = idx; + + using std::placeholders::_1; + using std::placeholders::_2; + + std::nth_element(begin, begin + std::distance(begin, end) / 2, end, + std::bind(&comparer::compare_idx, comp, _1, _2)); +} + +namespace detail { +inline bool compare_node_distance(std::pair a, + std::pair b) { + return a.second < b.second; +} +} // namespace detail + +using pointVec = std::vector; + +KDNodePtr KDTree::make_tree(pointIndexArr::iterator const& begin, + pointIndexArr::iterator const& end, + size_t const& level) { + if (begin == end) { + return leaf_; // empty tree + } + + assert(std::distance(begin, end) > 0); + + size_t const dim = begin->first.size(); + sort_on_idx(begin, end, level); + + auto const num_points = std::distance(begin, end); + auto const middle{std::next(begin, num_points / 2)}; + + size_t const next_level{(level + 1) % dim}; + KDNodePtr const left{make_tree(begin, middle, next_level)}; + KDNodePtr const right{make_tree(std::next(middle), end, next_level)}; + return std::make_shared(*middle, left, right); +} + +KDTree::KDTree(pointVec point_array) : leaf_{std::make_shared()} { + pointIndexArr arr; + for (size_t i = 0; i < point_array.size(); i++) { + arr.emplace_back(point_array.at(i), i); + } + root_ = KDTree::make_tree(arr.begin(), arr.end(), 0 /* level */); +} + +void KDTree::node_query_( + KDNodePtr const& branch, point_t const& pt, size_t const& level, + size_t const& num_nearest, + std::list>& k_nearest_buffer) { + if (!static_cast(*branch)) { + return; + } + knearest_(branch, pt, level, num_nearest, k_nearest_buffer); + double const dl = dist2(branch->x, pt); + // assert(*branch); + auto const node_distance = std::make_pair(branch, dl); + auto const insert_it = + std::upper_bound(k_nearest_buffer.begin(), k_nearest_buffer.end(), + node_distance, detail::compare_node_distance); + if (insert_it != k_nearest_buffer.end() || + k_nearest_buffer.size() < num_nearest) { + k_nearest_buffer.insert(insert_it, node_distance); + } + while (k_nearest_buffer.size() > num_nearest) { + k_nearest_buffer.pop_back(); + } +} + +void KDTree::knearest_( + KDNodePtr const& branch, point_t const& pt, size_t const& level, + size_t const& num_nearest, + std::list>& k_nearest_buffer) { + if (branch == nullptr || !static_cast(*branch)) { + return; + } + + point_t branch_pt{*branch}; + size_t dim = branch_pt.size(); + assert(dim != 0); + assert(dim == pt.size()); + + double const dx = branch_pt.at(level) - pt.at(level); + double const dx2 = dx * dx; + + // select which branch makes sense to check + KDNodePtr const close_branch = (dx > 0) ? branch->left : branch->right; + KDNodePtr const far_branch = (dx > 0) ? branch->right : branch->left; + + size_t const next_level = (level + 1) % dim; + node_query_(close_branch, pt, next_level, num_nearest, k_nearest_buffer); + + // only check the other branch if it makes sense to do so + if (dx2 < k_nearest_buffer.back().second || + k_nearest_buffer.size() < num_nearest) { + node_query_(far_branch, pt, next_level, num_nearest, k_nearest_buffer); + } +}; + +// default caller +KDNodePtr KDTree::nearest_(point_t const& pt) { + size_t level = 0; + std::list> k_buffer{}; + k_buffer.emplace_back(root_, dist2(static_cast(*root_), pt)); + knearest_(root_, // beginning of tree + pt, // point we are querying + level, // start from level 0 + 1, // number of nearest neighbours to return in k_buffer + k_buffer // list of k nearest neigbours (to be filled) + ); + if (k_buffer.size() > 0) { + return k_buffer.front().first; + } + return nullptr; +}; + +point_t KDTree::nearest_point(point_t const& pt) { + return static_cast(*nearest_(pt)); +} + +size_t KDTree::nearest_index(point_t const& pt) { + return static_cast(*nearest_(pt)); +} + +pointIndex KDTree::nearest_pointIndex(point_t const& pt) { + KDNodePtr Nearest = nearest_(pt); + return static_cast(*Nearest); +} + +pointIndexArr KDTree::nearest_pointIndices(point_t const& pt, + size_t const& num_nearest) { + size_t level = 0; + std::list> k_buffer{}; + k_buffer.emplace_back(root_, dist2(static_cast(*root_), pt)); + knearest_(root_, // beginning of tree + pt, // point we are querying + level, // start from level 0 + num_nearest, // number of nearest neighbours to return in k_buffer + k_buffer); // list of k nearest neigbours (to be filled) + pointIndexArr output{num_nearest}; + std::transform(k_buffer.begin(), k_buffer.end(), output.begin(), + [](auto const& nodeptr_dist) { + return static_cast(*(nodeptr_dist.first)); + }); + return output; +} + +pointVec KDTree::nearest_points(point_t const& pt, size_t const& num_nearest) { + auto const k_nearest{nearest_pointIndices(pt, num_nearest)}; + pointVec k_nearest_points(k_nearest.size()); + std::transform(k_nearest.begin(), k_nearest.end(), k_nearest_points.begin(), + [](pointIndex const& x) { return x.first; }); + return k_nearest_points; +} + +indexArr KDTree::nearest_indices(point_t const& pt, size_t const& num_nearest) { + auto const k_nearest{nearest_pointIndices(pt, num_nearest)}; + indexArr k_nearest_indices(k_nearest.size()); + std::transform(k_nearest.begin(), k_nearest.end(), + k_nearest_indices.begin(), + [](pointIndex const& x) { return x.second; }); + return k_nearest_indices; +} + +void KDTree::neighborhood_(KDNodePtr const& branch, point_t const& pt, + double const& rad2, size_t const& level, + pointIndexArr& nbh) { + if (!bool(*branch)) { + // branch has no point, means it is a leaf, + // no points to add + return; + } + + size_t const dim = pt.size(); + + double const d = dist2(static_cast(*branch), pt); + double const dx = static_cast(*branch).at(level) - pt.at(level); + double const dx2 = dx * dx; + + if (d <= rad2) { + nbh.push_back(static_cast(*branch)); + } + + KDNodePtr const close_branch = (dx > 0) ? branch->left : branch->right; + KDNodePtr const far_branch = (dx > 0) ? branch->right : branch->left; + + size_t const next_level{(level + 1) % dim}; + neighborhood_(close_branch, pt, rad2, next_level, nbh); + if (dx2 < rad2) { + neighborhood_(far_branch, pt, rad2, next_level, nbh); + } +} + +pointIndexArr KDTree::neighborhood(point_t const& pt, double const& rad) { + pointIndexArr nbh; + neighborhood_(root_, pt, rad * rad, /*level*/ 0, nbh); + return nbh; +} + +pointVec KDTree::neighborhood_points(point_t const& pt, double const& rad) { + auto nbh = std::make_shared(); + neighborhood_(root_, pt, rad * rad, /*level*/ 0, *nbh); + pointVec nbhp(nbh->size()); + auto const first = [](pointIndex const& x) { return x.first; }; + std::transform(nbh->begin(), nbh->end(), nbhp.begin(), first); + return nbhp; +} + +indexArr KDTree::neighborhood_indices(point_t const& pt, double const& rad) { + auto nbh = std::make_shared(); + neighborhood_(root_, pt, rad * rad, /*level*/ 0, *nbh); + indexArr nbhi(nbh->size()); + auto const second = [](pointIndex const& x) { return x.second; }; + std::transform(nbh->begin(), nbh->end(), nbhi.begin(), second); + return nbhi; +} \ No newline at end of file diff --git a/src/utils/KDTree.hpp b/src/utils/KDTree.hpp new file mode 100644 index 0000000..04b39a2 --- /dev/null +++ b/src/utils/KDTree.hpp @@ -0,0 +1,179 @@ +#pragma once + +/// @file KDTree.hpp +/// @author J. Frederico Carvalho +/// +/// This is an adaptation of the KD-tree implementation in rosetta code +/// https://rosettacode.org/wiki/K-d_tree +/// It is a reimplementation of the C code using C++. +/// It also includes a few more queries than the original + +#include +#include +#include +#include +#include + +/// The point type (vector of double precision floats) +using point_t = std::vector; + +/// Array of indices +using indexArr = std::vector; + +/// Pair of point and Index +using pointIndex = typename std::pair, size_t>; + +class KDNode { + public: + using KDNodePtr = std::shared_ptr; + size_t index; + point_t x; + KDNodePtr left; + KDNodePtr right; + + // initializer + KDNode(); + KDNode(point_t const&, size_t const&, KDNodePtr const&, KDNodePtr const&); + KDNode(pointIndex const&, KDNodePtr const&, KDNodePtr const&); + ~KDNode(); + + // getter + double coord(size_t const&); + + // conversions + explicit operator bool(); + explicit operator point_t(); + explicit operator size_t(); + explicit operator pointIndex(); +}; + +using KDNodePtr = std::shared_ptr; + +KDNodePtr NewKDNodePtr(); + +// square euclidean distance +inline double dist2(point_t const&, point_t const&); +inline double dist2(KDNodePtr const&, KDNodePtr const&); + +// Need for sorting +class comparer { + public: + size_t idx; + explicit comparer(size_t idx_); + inline bool compare_idx(std::pair, size_t> const&, // + std::pair, size_t> const& // + ); +}; + +using pointIndexArr = typename std::vector; + +inline void sort_on_idx(pointIndexArr::iterator const&, // + pointIndexArr::iterator const&, // + size_t idx); + +using pointVec = std::vector; + +class KDTree { + + public: + KDTree() = default; + + /// Build a KDtree + explicit KDTree(pointVec point_array); + + /// Get the point which lies closest to the input point. + /// @param pt input point. + point_t nearest_point(point_t const& pt); + + /// Get the index of the point which lies closest to the input point. + /// + /// @param pt input point. + size_t nearest_index(point_t const& pt); + + /// Get the point and its index which lies closest to the input point. + /// + /// @param pt input point. + pointIndex nearest_pointIndex(point_t const& pt); + + /// Get both the point and the index of the points closest to the input + /// point. + /// + /// @param pt input point. + /// @param num_nearest Number of nearest points to return. + /// + /// @returns a vector containing the points and their respective indices + /// which are at a distance smaller than rad to the input point. + pointIndexArr nearest_pointIndices(point_t const& pt, + size_t const& num_nearest); + + /// Get the nearest set of points to the given input point. + /// + /// @param pt input point. + /// @param num_nearest Number of nearest points to return. + /// + /// @returns a vector containing the points which are at a distance smaller + /// than rad to the input point. + pointVec nearest_points(point_t const& pt, size_t const& num_nearest); + + /// Get the indices of points closest to the input point. + /// + /// @param pt input point. + /// @param num_nearest Number of nearest points to return. + /// + /// @returns a vector containing the indices of the points which are at a + /// distance smaller than rad to the input point. + indexArr nearest_indices(point_t const& pt, size_t const& num_nearest); + + /// Get both the point and the index of the points which are at a distance + /// smaller than the input radius to the input point. + /// + /// @param pt input point. + /// @param rad input radius. + /// + /// @returns a vector containing the points and their respective indices + /// which are at a distance smaller than rad to the input point. + pointIndexArr neighborhood(point_t const& pt, double const& rad); + + /// Get the points that are at a distance to the input point which is + /// smaller than the input radius. + /// + /// @param pt input point. + /// @param rad input radius. + /// + /// @returns a vector containing the points which are at a distance smaller + /// than rad to the input point. + pointVec neighborhood_points(point_t const& pt, double const& rad); + + /// Get the indices of points that are at a distance to the input point + /// which is smaller than the input radius. + /// + /// @param pt input point. + /// @param rad input radius. + /// + /// @returns a vector containing the indices of the points which are at a + /// distance smaller than rad to the input point. + indexArr neighborhood_indices(point_t const& pt, double const& rad); + + private: + KDNodePtr make_tree(pointIndexArr::iterator const& begin, + pointIndexArr::iterator const& end, + size_t const& level); + + void knearest_(KDNodePtr const& branch, point_t const& pt, + size_t const& level, size_t const& num_nearest, + std::list>& k_nearest_buffer); + + void node_query_(KDNodePtr const& branch, point_t const& pt, + size_t const& level, size_t const& num_nearest, + std::list>& k_nearest_buffer); + + // default caller + KDNodePtr nearest_(point_t const& pt); + + void neighborhood_(KDNodePtr const& branch, point_t const& pt, + double const& rad2, size_t const& level, + pointIndexArr& nbh); + + KDNodePtr root_; + KDNodePtr leaf_; +}; \ No newline at end of file