Browse Source

kd tree, image2mesh, kd query, esp32 frag

master
cailean 4 months ago
parent
commit
8fa99dfce1
  1. 1
      .gitignore
  2. BIN
      bin/data/debug_binaryMat.png
  3. BIN
      bin/data/new_tree.bin
  4. 38
      bin/data/shaders/espDepth.frag
  5. 13
      bin/data/shaders/espDepth.vert
  6. BIN
      bin/image-to-mesh
  7. 60
      src/Bullet.cpp
  8. 4
      src/Bullet.h
  9. 136
      src/ofApp.cpp
  10. 9
      src/ofApp.h
  11. 292
      src/utils/KDTree.cpp
  12. 179
      src/utils/KDTree.hpp

1
.gitignore

@ -12,6 +12,7 @@
/bin/data/images
/bin/data/recordings
/bin/data/models
/bin/data/dataset
/bin/libs/
#########

BIN
bin/data/debug_binaryMat.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.2 KiB

After

Width:  |  Height:  |  Size: 3.9 KiB

BIN
bin/data/new_tree.bin

Binary file not shown.

38
bin/data/shaders/espDepth.frag

@ -0,0 +1,38 @@
#version 150
precision highp float;
uniform sampler2DRect tex0; // Your FBO texture (RGB FBO)
uniform float texW; // FBO texture width
uniform float texH; // FBO texture height
out vec4 fragColor; // Output color
in vec2 varyingtexcoord;
void main() {
float threshold = 0.7;
float depth = texture(tex0, varyingtexcoord).r;
// Determine which quadrant we're in
int quadrantX = int(varyingtexcoord.x / (256.0 / 2.0));
int quadrantY = int(varyingtexcoord.y / (336.0 / 2.0));
int quadrant = quadrantY * 2 + quadrantX;
vec3 baseColor;
vec3 mixColor = vec3(0.0); // White for mixing in all quadrants
// Choose base color based on quadrant
if (quadrant == 0) {
baseColor = vec3(1.0, 0.0, 0.0); // Red for bottom-left
} else if (quadrant == 1) {
baseColor = vec3(0.0, 1.0, 0.0); // Green for bottom-right
} else if (quadrant == 2) {
baseColor = vec3(0.0, 0.0, 1.0); // Blue for top-left
} else {
baseColor = vec3(1.0, 1.0, 1.0); // White for top-right
}
// Mix the base color with white based on the depth
vec3 color = mix(mixColor, baseColor, depth);
fragColor = vec4(color, 1.0);
}

13
bin/data/shaders/espDepth.vert

@ -0,0 +1,13 @@
#version 150
uniform mat4 modelViewProjectionMatrix;
in vec4 position;
in vec2 texcoord;
out vec2 varyingtexcoord;
void main(){
varyingtexcoord = texcoord;
gl_Position = modelViewProjectionMatrix * position;
}

BIN
bin/image-to-mesh

Binary file not shown.

60
src/Bullet.cpp

@ -26,7 +26,7 @@ void Bullet::setup(vector<Node>& _nodes){
shader.load("shaders/vertex_snap");
/* assign nodes */
nodes = _nodes;
//nodes = _nodes;
std::cout << workerThreads.size() << std::endl;
}
@ -128,19 +128,18 @@ void Bullet::draw(){
now = false;
}
void Bullet::addMesh(ofMesh _mesh, ofMesh _simple_mesh, ofTexture _tex){
void Bullet::addMesh(ofMesh _mesh, ofMesh _simple_mesh, Node& _node){
std::lock_guard<std::mutex> lock(shapeMutex);
Node n;
n.tex = _tex;
// n.tex = _tex;
float rand = ofRandom(0.01, 0.02);
glm::vec3 random_szie(rand, rand, rand);
n.scale = random_szie;
_node.scale = random_szie;
ofQuaternion startRot = ofQuaternion(0., 0., 0., PI);
ofVec3f target_location = ofVec3f( ofRandom(0, 0), ofRandom(0, 0), -5 );
ofVec3f start_location = ofVec3f( ofRandom(-300, 300), ofRandom(-300, 300), -5 );
// ofVec3f target_location = ofVec3f( ofRandom(0, 0), ofRandom(0, 0), -5 );
ofVec3f start_location = ofVec3f( ofRandom(-3000, 3000), ofRandom(-3000, 3000) -5 );
ofxBulletCustomShape* s = new ofxBulletCustomShape();
s->addMesh(_simple_mesh, n.scale * 1.4, true);
s->addMesh(_simple_mesh, _node.scale * 1.4, true);
s->create( world.world, start_location, startRot, 3.);
s->add();
s->getRigidBody()->setAngularFactor(btVector3(0, 0, 1));
@ -150,15 +149,15 @@ void Bullet::addMesh(ofMesh _mesh, ofMesh _simple_mesh, ofTexture _tex){
s->getRigidBody()->setRestitution(btScalar(0.5));
s->getRigidBody()->setFriction(btScalar(1.0));
positions.push_back(target_location);
//positions.push_back(target_location);
// Set how the col mesh is drawn!
_simple_mesh.setMode(OF_PRIMITIVE_LINES);
n.collider = s;
n.mesh = _mesh;
n.col_mesh = _simple_mesh;
nodes.push_back(n);
_node.collider = s;
_node.mesh = _mesh;
_node.col_mesh = _simple_mesh;
// nodes.push_back(n);
}
void Bullet::workerThreadFunction(int threadId) {
@ -209,24 +208,25 @@ void Bullet::updateShapeBatch(size_t start, size_t end) {
std::lock_guard<std::mutex> lock(shapeMutex);
glm::vec3 pos = nodes[i].collider->getPosition();
glm::vec3 direction = glm::vec3(0, 0, -5) - pos;
glm::vec3 target_pos(nodes[i].tsne_position.x, nodes[i].tsne_position.y, -5);
glm::vec3 direction = target_pos - pos;
float dist_sq = glm::length(direction);
glm::vec3 norm_dir = glm::normalize(direction);
nodes[i].collider->applyCentralForce(1.0f * norm_dir);
// Apply repulsion force if needed
if (shouldApplyRepulsion && i != repulsionNodeIndex) {
const Node& repulsionNode = nodes[repulsionNodeIndex];
glm::vec3 repulsionDir = pos - repulsionNode.collider->getPosition();
float repulsionDist = glm::length(repulsionDir);
if (repulsionDist < repulsionRadius && repulsionDist > 0) {
repulsionDir = glm::normalize(repulsionDir);
float forceMagnitude = repulsionStrength * (1.0f - repulsionDist / repulsionRadius);
glm::vec3 repulsionForce = repulsionDir * forceMagnitude;
nodes[i].collider->applyCentralForce(repulsionForce * 100.0f);
}
}
// if (shouldApplyRepulsion && i != repulsionNodeIndex) {
// const Node& repulsionNode = nodes[repulsionNodeIndex];
// glm::vec3 repulsionDir = pos - repulsionNode.collider->getPosition();
// float repulsionDist = glm::length(repulsionDir);
// if (repulsionDist < repulsionRadius && repulsionDist > 0) {
// repulsionDir = glm::normalize(repulsionDir);
// float forceMagnitude = repulsionStrength * (1.0f - repulsionDist / repulsionRadius);
// glm::vec3 repulsionForce = repulsionDir * forceMagnitude;
// nodes[i].collider->applyCentralForce(repulsionForce * 100.0f);
// }
// }
}
}
@ -267,4 +267,12 @@ bool Bullet::checkNodeVisibility(const Node& n){
} else {
return false;
}
}
glm::vec3 Bullet::getCameraPosition(){
return camera.getPosition();
}
void Bullet::setNodes(vector<Node>& _nodes){
nodes = _nodes;
}

4
src/Bullet.h

@ -22,12 +22,14 @@ class Bullet{
void setup(vector<Node>& _nodes);
void update();
void draw();
void addMesh(ofMesh _mesh, ofMesh _simple_mesh, ofTexture _tex);
void addMesh(ofMesh _mesh, ofMesh _simple_mesh, Node& _node);
float easeInOutCubic(float t);
float calculateGridSize(float zoom);
void setNewCameraEndpoint();
glm::vec4 getCameraBounds(const ofCamera& cam);
bool checkNodeVisibility(const Node& n);
glm::vec3 getCameraPosition();
void setNodes(vector<Node>& _nodes);
ofxBulletWorldRigid world;
vector <ofxBulletBox*> bounds;
ofxBulletCustomShape* boundsShape;

136
src/ofApp.cpp

@ -11,6 +11,7 @@ void ofApp::setup(){
model_esp_out_fbo.allocate(128 * 2, 168 * 2, GL_RGB);
/* k-d image comp (4-images) */
kd_out.allocate(128 * 2, 168 * 2, GL_RGB);
esp_comp_fbo.allocate(128 * 2, 168 * 2, GL_RGB);
/* input images for model */
@ -34,11 +35,12 @@ void ofApp::setup(){
/* load */
shaders.load("shaders/dither");
esp_shader.load("shaders/espDepth");
ORTCHAR_T* modelPath = "/home/cailean/Desktop/openframeworks/of_v0.12.0_linux64gcc6_release/apps/myApps/image-to-mesh/bin/data/models/depth_anything_v2_vits.onnx";
/* setup */
//bullet.setup(nodes);
bullet.setup(nodes);
depth_onnx.Setup(modelPath, true, true);
depth_onnx_esp.Setup(modelPath, true, true);
@ -46,52 +48,31 @@ void ofApp::setup(){
depth_thread.setup(&model_image, &model_outptut_fbo, &depth_onnx);
depth_esp.setup(&model_image_esp, &model_esp_out_fbo, &depth_onnx_esp);
/* mesh generation test */
std::string path = "images";
ofDirectory dir(path);
dir.allowExt("png");
dir.listDir();
ofLoadImage(bayer, "images/bayer.png");
bayer.setTextureWrap(GL_REPEAT, GL_REPEAT);
// Loop through files and load images
for (size_t i = 0; i < dir.size(); i++) {
std::string filePath = dir.getPath(i);
ofImage img;
if (img.load(filePath)) { // Load the image
images.push_back(img); // Add the image to the list
} else {
ofLog() << "Failed to load: " << filePath; // Log if failed to load
}
}
// vector<ofMesh> mesh_list;
createNodes("data/json/embeddings.json");
// for(auto& img : images){
// mesh_list = mesh_generator.generateSimplifiedMesh(img);
// bullet.addMesh(mesh_list[0], mesh_list[1], img.getTexture());
// mesh_list.clear();
// }
buildMeshes();
last_updated_time = ofGetElapsedTimef();
buildKDTree();
createNodes("data/json/embeddings.json");
bullet.setNodes(nodes);
server = std::make_unique<Server>(6762, embed, vp_tree, nodes, false, "192.168.0.253", 2000, "search");
server->start();
last_updated_time = ofGetElapsedTimef();
}
//--------------------------------------------------------------
void ofApp::update(){
server->update(model_esp_out_fbo);
server->update(esp_comp_fbo);
float current_time = ofGetElapsedTimef();
if(current_time - last_updated_time >= 3){
buildKDTree();
getNearestImages();
last_updated_time = current_time;
}
@ -120,7 +101,7 @@ void ofApp::update(){
std::cout << "Model did not run" << std::endl;
}
//bullet.update();
bullet.update();
}
//--------------------------------------------------------------
@ -129,7 +110,7 @@ void ofApp::draw(){
ofPushStyle();
map_fbo.begin();
ofClear(ofColor::grey);
//bullet.draw();
bullet.draw();
map_fbo.end();
ofPopStyle();
@ -152,7 +133,17 @@ void ofApp::draw(){
comp_fbo.draw(0,0);
model_outptut_fbo.draw(ofGetWindowWidth() / 2, 0);
esp_comp_fbo.begin();
esp_shader.begin();
esp_shader.setUniformTexture("tex0", model_esp_out_fbo.getTexture(), 0);
model_esp_out_fbo.draw(0, 0);
esp_comp_fbo.end();
esp_shader.end();
esp_comp_fbo.draw(0, 0);
//server->print();
//shader_fbo.draw(0, 0);
@ -182,25 +173,28 @@ void ofApp::draw(){
/* creates an fbo with four cropped images, to preprare for model input */
void ofApp::getNearestImages(){
esp_comp_fbo.begin();
glm::vec3 cam_position = bullet.getCameraPosition();
queryKD(cam_position, 4);
kd_out.begin();
ofClear(255, 255, 255, 0);
ofImage random_image;
ofImage selected_image;
int imageIndex = 0;
for(auto& img : esp_images){
random_image = images[ofRandom(images.size() - 1)];
selected_image = nodes[kd_result[imageIndex].second].img;
// Calculate the scaling factor
float widthRatio = 128.0f / random_image.getWidth();
float heightRatio = 168.0f / random_image.getHeight();
float widthRatio = 128.0f / selected_image.getWidth();
float heightRatio = 168.0f / selected_image.getHeight();
float scale = std::max(widthRatio, heightRatio);
// Calculate new dimensions
int newWidth = std::ceil(random_image.getWidth() * scale);
int newHeight = std::ceil(random_image.getHeight() * scale);
int newWidth = std::ceil(selected_image.getWidth() * scale);
int newHeight = std::ceil(selected_image.getHeight() * scale);
// Resize the image
random_image.resize(newWidth, newHeight);
selected_image.resize(newWidth, newHeight);
// Calculate the crop position to center the image
int cropX = (newWidth - 128) / 2;
@ -215,15 +209,15 @@ void ofApp::getNearestImages(){
int drawY = (imageIndex / 2) * 168;
// Draw the resized and cropped image
random_image.drawSubsection(drawX, drawY, cropWidth, cropHeight, cropX, cropY);
selected_image.drawSubsection(drawX, drawY, cropWidth, cropHeight, cropX, cropY);
imageIndex++;
if (imageIndex >= 4) break; // Stop after drawing 4 images
}
esp_comp_fbo.end();
kd_out.end();
esp_comp_fbo.readToPixels(esp_comp_pixels);
kd_out.readToPixels(esp_comp_pixels);
model_image_esp.setFromPixels(esp_comp_pixels);
}
@ -244,6 +238,7 @@ void ofApp::createNodes(std::string json_path){
if(j.contains("vector") && j["vector"].is_array()){
Node n;
n.img.load(j["image"]);
n.tex = n.img.getTexture();
std::vector<float> t_embedding;
for (const auto& value: j["vector"]){
@ -278,9 +273,8 @@ void ofApp::createNodes(std::string json_path){
for(size_t i = 0; i < tsne_points.size(); i++){
const auto& vec = tsne_points[i];
auto& n = nodes[i];
n.tsne_position = (glm::vec3(vec[0] * tsne_scale, vec[1] * tsne_scale, -5.0f));
n.tsne_position = (glm::vec3(((vec[0] * 2) - 1) * tsne_scale, ((vec[1] * 2) - 1) * tsne_scale, -5.0f));
}
/* vp-test */
// auto queries = server->generateRandomVectors(10, 7);
@ -330,3 +324,55 @@ void ofApp::exit(){
server->close();
}
void ofApp::buildMeshes(){
vector<ofMesh> mesh_list;
ofLogNotice() << "building meshes: " << nodes.size();
for(auto& n : nodes){
mesh_list = mesh_generator.generateSimplifiedMesh(n.img);
bullet.addMesh(mesh_list[0], mesh_list[1], n);
mesh_list.clear();
}
ofLogNotice() << "finished building meshes";
}
void ofApp::buildKDTree(){
pointVec kd_points;
for(const auto& n : nodes){
glm::vec3 n_pos = n.collider->getPosition();
std::vector< double > p = {
static_cast<double>(n_pos.x),
static_cast<double>(n_pos.y)
};
kd_points.push_back(p);
}
kd_tree = std::make_unique<KDTree>(kd_points);
kd_result.resize(4);
}
void ofApp::queryKD(glm::vec3& _position, int k){
kd_result.clear();
vector<double> kd_input = {
static_cast<double>(_position.x),
static_cast<double>(_position.y)
};
auto res = kd_tree->nearest_pointIndices(kd_input, k);
for (const auto& r : res) {
kd_result.push_back(r);
}
size_t index = kd_result[0].second;
std::cout << "camera positions: " << _position << ", nearest node: "<< nodes[index].tsne_position << " " << index << std::endl;
}

9
src/ofApp.h

@ -11,6 +11,7 @@
#include "network/Request.h"
#include "network/Server.h"
#include "ofxTSNE.h"
#include "utils/KDTree.hpp"
class ofApp : public ofBaseApp{
@ -21,6 +22,9 @@ class ofApp : public ofBaseApp{
void getNearestImages();
void createNodes(std::string json_path);
std::vector<std::vector<double>> createDoubleVectorFromNodes(const std::vector<Node>& nodes);
void buildKDTree();
void queryKD(glm::vec3& _position, int k);
void buildMeshes();
void exit();
ofEasyCam cam;
@ -37,9 +41,11 @@ class ofApp : public ofBaseApp{
ofFbo model_esp_out_fbo;
ofFbo esp_comp_fbo;
ofFbo kd_out;
ofImage model_image_esp;
ofShader shaders;
ofShader esp_shader;
ofPlanePrimitive plane;
ofTexture bayer;
@ -77,5 +83,8 @@ class ofApp : public ofBaseApp{
bool runManually = false;
float tsne_scale = 1000;
/* kd tree */
std::unique_ptr<KDTree> kd_tree;
std::vector<pointIndex> kd_result;
};

292
src/utils/KDTree.cpp

@ -0,0 +1,292 @@
/// @file KDTree.cpp
/// @author J. Frederico Carvalho
///
/// This is an adaptation of the KD-tree implementation in rosetta code
/// https://rosettacode.org/wiki/K-d_tree
///
/// It is a reimplementation of the C code using C++. It also includes a few
/// more queries than the original, namely finding all points at a distance
/// smaller than some given distance to a point.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <functional>
#include <iterator>
#include <limits>
#include <memory>
#include <vector>
#include "KDTree.hpp"
KDNode::KDNode() = default;
KDNode::KDNode(point_t const& pt, size_t const& idx_, KDNodePtr const& left_,
KDNodePtr const& right_) {
x = pt;
index = idx_;
left = left_;
right = right_;
}
KDNode::KDNode(pointIndex const& pi, KDNodePtr const& left_,
KDNodePtr const& right_) {
x = pi.first;
index = pi.second;
left = left_;
right = right_;
}
KDNode::~KDNode() = default;
double KDNode::coord(size_t const& idx) { return x.at(idx); }
KDNode::operator bool() { return (!x.empty()); }
KDNode::operator point_t() { return x; }
KDNode::operator size_t() { return index; }
KDNode::operator pointIndex() { return std::make_pair(x, index); }
KDNodePtr NewKDNodePtr() {
KDNodePtr mynode = std::make_shared<KDNode>();
return mynode;
}
inline double dist2(point_t const& a, point_t const& b) {
assert(a.size() == b.size());
double distc = 0;
for (size_t i = 0; i < a.size(); i++) {
double di = a.at(i) - b.at(i);
distc += di * di;
}
return distc;
}
inline double dist2(KDNodePtr const& a, KDNodePtr const& b) {
return dist2(a->x, b->x);
}
comparer::comparer(size_t idx_) : idx{idx_} {}
inline bool comparer::compare_idx(pointIndex const& a, pointIndex const& b) {
return (a.first.at(idx) < b.first.at(idx));
}
inline void sort_on_idx(pointIndexArr::iterator const& begin,
pointIndexArr::iterator const& end, size_t idx) {
comparer comp(idx);
comp.idx = idx;
using std::placeholders::_1;
using std::placeholders::_2;
std::nth_element(begin, begin + std::distance(begin, end) / 2, end,
std::bind(&comparer::compare_idx, comp, _1, _2));
}
namespace detail {
inline bool compare_node_distance(std::pair<KDNodePtr, double> a,
std::pair<KDNodePtr, double> b) {
return a.second < b.second;
}
} // namespace detail
using pointVec = std::vector<point_t>;
KDNodePtr KDTree::make_tree(pointIndexArr::iterator const& begin,
pointIndexArr::iterator const& end,
size_t const& level) {
if (begin == end) {
return leaf_; // empty tree
}
assert(std::distance(begin, end) > 0);
size_t const dim = begin->first.size();
sort_on_idx(begin, end, level);
auto const num_points = std::distance(begin, end);
auto const middle{std::next(begin, num_points / 2)};
size_t const next_level{(level + 1) % dim};
KDNodePtr const left{make_tree(begin, middle, next_level)};
KDNodePtr const right{make_tree(std::next(middle), end, next_level)};
return std::make_shared<KDNode>(*middle, left, right);
}
KDTree::KDTree(pointVec point_array) : leaf_{std::make_shared<KDNode>()} {
pointIndexArr arr;
for (size_t i = 0; i < point_array.size(); i++) {
arr.emplace_back(point_array.at(i), i);
}
root_ = KDTree::make_tree(arr.begin(), arr.end(), 0 /* level */);
}
void KDTree::node_query_(
KDNodePtr const& branch, point_t const& pt, size_t const& level,
size_t const& num_nearest,
std::list<std::pair<KDNodePtr, double>>& k_nearest_buffer) {
if (!static_cast<bool>(*branch)) {
return;
}
knearest_(branch, pt, level, num_nearest, k_nearest_buffer);
double const dl = dist2(branch->x, pt);
// assert(*branch);
auto const node_distance = std::make_pair(branch, dl);
auto const insert_it =
std::upper_bound(k_nearest_buffer.begin(), k_nearest_buffer.end(),
node_distance, detail::compare_node_distance);
if (insert_it != k_nearest_buffer.end() ||
k_nearest_buffer.size() < num_nearest) {
k_nearest_buffer.insert(insert_it, node_distance);
}
while (k_nearest_buffer.size() > num_nearest) {
k_nearest_buffer.pop_back();
}
}
void KDTree::knearest_(
KDNodePtr const& branch, point_t const& pt, size_t const& level,
size_t const& num_nearest,
std::list<std::pair<KDNodePtr, double>>& k_nearest_buffer) {
if (branch == nullptr || !static_cast<bool>(*branch)) {
return;
}
point_t branch_pt{*branch};
size_t dim = branch_pt.size();
assert(dim != 0);
assert(dim == pt.size());
double const dx = branch_pt.at(level) - pt.at(level);
double const dx2 = dx * dx;
// select which branch makes sense to check
KDNodePtr const close_branch = (dx > 0) ? branch->left : branch->right;
KDNodePtr const far_branch = (dx > 0) ? branch->right : branch->left;
size_t const next_level = (level + 1) % dim;
node_query_(close_branch, pt, next_level, num_nearest, k_nearest_buffer);
// only check the other branch if it makes sense to do so
if (dx2 < k_nearest_buffer.back().second ||
k_nearest_buffer.size() < num_nearest) {
node_query_(far_branch, pt, next_level, num_nearest, k_nearest_buffer);
}
};
// default caller
KDNodePtr KDTree::nearest_(point_t const& pt) {
size_t level = 0;
std::list<std::pair<KDNodePtr, double>> k_buffer{};
k_buffer.emplace_back(root_, dist2(static_cast<point_t>(*root_), pt));
knearest_(root_, // beginning of tree
pt, // point we are querying
level, // start from level 0
1, // number of nearest neighbours to return in k_buffer
k_buffer // list of k nearest neigbours (to be filled)
);
if (k_buffer.size() > 0) {
return k_buffer.front().first;
}
return nullptr;
};
point_t KDTree::nearest_point(point_t const& pt) {
return static_cast<point_t>(*nearest_(pt));
}
size_t KDTree::nearest_index(point_t const& pt) {
return static_cast<size_t>(*nearest_(pt));
}
pointIndex KDTree::nearest_pointIndex(point_t const& pt) {
KDNodePtr Nearest = nearest_(pt);
return static_cast<pointIndex>(*Nearest);
}
pointIndexArr KDTree::nearest_pointIndices(point_t const& pt,
size_t const& num_nearest) {
size_t level = 0;
std::list<std::pair<KDNodePtr, double>> k_buffer{};
k_buffer.emplace_back(root_, dist2(static_cast<point_t>(*root_), pt));
knearest_(root_, // beginning of tree
pt, // point we are querying
level, // start from level 0
num_nearest, // number of nearest neighbours to return in k_buffer
k_buffer); // list of k nearest neigbours (to be filled)
pointIndexArr output{num_nearest};
std::transform(k_buffer.begin(), k_buffer.end(), output.begin(),
[](auto const& nodeptr_dist) {
return static_cast<pointIndex>(*(nodeptr_dist.first));
});
return output;
}
pointVec KDTree::nearest_points(point_t const& pt, size_t const& num_nearest) {
auto const k_nearest{nearest_pointIndices(pt, num_nearest)};
pointVec k_nearest_points(k_nearest.size());
std::transform(k_nearest.begin(), k_nearest.end(), k_nearest_points.begin(),
[](pointIndex const& x) { return x.first; });
return k_nearest_points;
}
indexArr KDTree::nearest_indices(point_t const& pt, size_t const& num_nearest) {
auto const k_nearest{nearest_pointIndices(pt, num_nearest)};
indexArr k_nearest_indices(k_nearest.size());
std::transform(k_nearest.begin(), k_nearest.end(),
k_nearest_indices.begin(),
[](pointIndex const& x) { return x.second; });
return k_nearest_indices;
}
void KDTree::neighborhood_(KDNodePtr const& branch, point_t const& pt,
double const& rad2, size_t const& level,
pointIndexArr& nbh) {
if (!bool(*branch)) {
// branch has no point, means it is a leaf,
// no points to add
return;
}
size_t const dim = pt.size();
double const d = dist2(static_cast<point_t>(*branch), pt);
double const dx = static_cast<point_t>(*branch).at(level) - pt.at(level);
double const dx2 = dx * dx;
if (d <= rad2) {
nbh.push_back(static_cast<pointIndex>(*branch));
}
KDNodePtr const close_branch = (dx > 0) ? branch->left : branch->right;
KDNodePtr const far_branch = (dx > 0) ? branch->right : branch->left;
size_t const next_level{(level + 1) % dim};
neighborhood_(close_branch, pt, rad2, next_level, nbh);
if (dx2 < rad2) {
neighborhood_(far_branch, pt, rad2, next_level, nbh);
}
}
pointIndexArr KDTree::neighborhood(point_t const& pt, double const& rad) {
pointIndexArr nbh;
neighborhood_(root_, pt, rad * rad, /*level*/ 0, nbh);
return nbh;
}
pointVec KDTree::neighborhood_points(point_t const& pt, double const& rad) {
auto nbh = std::make_shared<pointIndexArr>();
neighborhood_(root_, pt, rad * rad, /*level*/ 0, *nbh);
pointVec nbhp(nbh->size());
auto const first = [](pointIndex const& x) { return x.first; };
std::transform(nbh->begin(), nbh->end(), nbhp.begin(), first);
return nbhp;
}
indexArr KDTree::neighborhood_indices(point_t const& pt, double const& rad) {
auto nbh = std::make_shared<pointIndexArr>();
neighborhood_(root_, pt, rad * rad, /*level*/ 0, *nbh);
indexArr nbhi(nbh->size());
auto const second = [](pointIndex const& x) { return x.second; };
std::transform(nbh->begin(), nbh->end(), nbhi.begin(), second);
return nbhi;
}

179
src/utils/KDTree.hpp

@ -0,0 +1,179 @@
#pragma once
/// @file KDTree.hpp
/// @author J. Frederico Carvalho
///
/// This is an adaptation of the KD-tree implementation in rosetta code
/// https://rosettacode.org/wiki/K-d_tree
/// It is a reimplementation of the C code using C++.
/// It also includes a few more queries than the original
#include <algorithm>
#include <functional>
#include <list>
#include <memory>
#include <vector>
/// The point type (vector of double precision floats)
using point_t = std::vector<double>;
/// Array of indices
using indexArr = std::vector<size_t>;
/// Pair of point and Index
using pointIndex = typename std::pair<std::vector<double>, size_t>;
class KDNode {
public:
using KDNodePtr = std::shared_ptr<KDNode>;
size_t index;
point_t x;
KDNodePtr left;
KDNodePtr right;
// initializer
KDNode();
KDNode(point_t const&, size_t const&, KDNodePtr const&, KDNodePtr const&);
KDNode(pointIndex const&, KDNodePtr const&, KDNodePtr const&);
~KDNode();
// getter
double coord(size_t const&);
// conversions
explicit operator bool();
explicit operator point_t();
explicit operator size_t();
explicit operator pointIndex();
};
using KDNodePtr = std::shared_ptr<KDNode>;
KDNodePtr NewKDNodePtr();
// square euclidean distance
inline double dist2(point_t const&, point_t const&);
inline double dist2(KDNodePtr const&, KDNodePtr const&);
// Need for sorting
class comparer {
public:
size_t idx;
explicit comparer(size_t idx_);
inline bool compare_idx(std::pair<std::vector<double>, size_t> const&, //
std::pair<std::vector<double>, size_t> const& //
);
};
using pointIndexArr = typename std::vector<pointIndex>;
inline void sort_on_idx(pointIndexArr::iterator const&, //
pointIndexArr::iterator const&, //
size_t idx);
using pointVec = std::vector<point_t>;
class KDTree {
public:
KDTree() = default;
/// Build a KDtree
explicit KDTree(pointVec point_array);
/// Get the point which lies closest to the input point.
/// @param pt input point.
point_t nearest_point(point_t const& pt);
/// Get the index of the point which lies closest to the input point.
///
/// @param pt input point.
size_t nearest_index(point_t const& pt);
/// Get the point and its index which lies closest to the input point.
///
/// @param pt input point.
pointIndex nearest_pointIndex(point_t const& pt);
/// Get both the point and the index of the points closest to the input
/// point.
///
/// @param pt input point.
/// @param num_nearest Number of nearest points to return.
///
/// @returns a vector containing the points and their respective indices
/// which are at a distance smaller than rad to the input point.
pointIndexArr nearest_pointIndices(point_t const& pt,
size_t const& num_nearest);
/// Get the nearest set of points to the given input point.
///
/// @param pt input point.
/// @param num_nearest Number of nearest points to return.
///
/// @returns a vector containing the points which are at a distance smaller
/// than rad to the input point.
pointVec nearest_points(point_t const& pt, size_t const& num_nearest);
/// Get the indices of points closest to the input point.
///
/// @param pt input point.
/// @param num_nearest Number of nearest points to return.
///
/// @returns a vector containing the indices of the points which are at a
/// distance smaller than rad to the input point.
indexArr nearest_indices(point_t const& pt, size_t const& num_nearest);
/// Get both the point and the index of the points which are at a distance
/// smaller than the input radius to the input point.
///
/// @param pt input point.
/// @param rad input radius.
///
/// @returns a vector containing the points and their respective indices
/// which are at a distance smaller than rad to the input point.
pointIndexArr neighborhood(point_t const& pt, double const& rad);
/// Get the points that are at a distance to the input point which is
/// smaller than the input radius.
///
/// @param pt input point.
/// @param rad input radius.
///
/// @returns a vector containing the points which are at a distance smaller
/// than rad to the input point.
pointVec neighborhood_points(point_t const& pt, double const& rad);
/// Get the indices of points that are at a distance to the input point
/// which is smaller than the input radius.
///
/// @param pt input point.
/// @param rad input radius.
///
/// @returns a vector containing the indices of the points which are at a
/// distance smaller than rad to the input point.
indexArr neighborhood_indices(point_t const& pt, double const& rad);
private:
KDNodePtr make_tree(pointIndexArr::iterator const& begin,
pointIndexArr::iterator const& end,
size_t const& level);
void knearest_(KDNodePtr const& branch, point_t const& pt,
size_t const& level, size_t const& num_nearest,
std::list<std::pair<KDNodePtr, double>>& k_nearest_buffer);
void node_query_(KDNodePtr const& branch, point_t const& pt,
size_t const& level, size_t const& num_nearest,
std::list<std::pair<KDNodePtr, double>>& k_nearest_buffer);
// default caller
KDNodePtr nearest_(point_t const& pt);
void neighborhood_(KDNodePtr const& branch, point_t const& pt,
double const& rad2, size_t const& level,
pointIndexArr& nbh);
KDNodePtr root_;
KDNodePtr leaf_;
};
Loading…
Cancel
Save