Cinder

  • Main Page
  • Related Pages
  • Namespaces
  • Classes
  • Files
  • File List
  • File Members

include/cinder/KdTree.h

Go to the documentation of this file.
00001 /*
00002  Copyright (c) 2009, The Barbarian Group
00003  All rights reserved.
00004 
00005  Redistribution and use in source and binary forms, with or without modification, are permitted provided that
00006  the following conditions are met:
00007 
00008     * Redistributions of source code must retain the above copyright notice, this list of conditions and
00009     the following disclaimer.
00010     * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and
00011     the following disclaimer in the documentation and/or other materials provided with the distribution.
00012 
00013  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00014  WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
00015  PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
00016  ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
00017  TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
00018  HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00019  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00020  POSSIBILITY OF SUCH DAMAGE.
00021 */
00022 
00023 #pragma once
00024 
00025 #include "cinder/Cinder.h"
00026 #include "cinder/Vector.h"
00027 
00028 #include <vector>
00029 #include <float.h>
00030 #include <stdlib.h>
00031 #include <algorithm>
00032 #include <utility>
00033 
00034 namespace cinder {
00035 
00036 // KdTree Declarations
// One node of the flattened kd-tree. Nodes live in a single array; a node's left child (when
// hasLeftChild is set) is the next array element, and its right child is at index rightChild.
template<unsigned char K>
struct KdNode {
    // splitAxis is a 2-bit field that also stores the value K to mark a leaf, so K must be <= 3.
    // A negative array size fails to compile otherwise. This plays the role of static_assert while
    // remaining valid C++03.
    typedef char KdNodeRequiresKAtMost3[( K <= 3 ) ? 1 : -1];

    // Makes this an interior node that splits axis `a` at coordinate `p`, with no children yet.
    void init( float p, uint32_t a) {
        splitPos = p;
        splitAxis = a;
        rightChild = ~0;    // all ones in the 29-bit field (2^29 - 1): "no right child"
        hasLeftChild = 0;
    }
    // Makes this a leaf (splitAxis == K) with no children. splitPos is left unchanged.
    void initLeaf() {
        splitAxis = K;
        rightChild = ~0;
        hasLeftChild = 0;
    }
    // KdNode Data
    float splitPos;             // split coordinate along splitAxis (interior nodes only)
    uint32_t splitAxis:2;       // axis index 0..K-1, or K for a leaf
    uint32_t hasLeftChild:1;    // 1 if the left child exists (stored at this node's index + 1)
    uint32_t rightChild:29;     // index of the right child; an index >= node count means "none"
};
00056 
// No-op lookup callback: KdTree's default LookupProc. It ignores every point it is handed.
struct NullLookupProc {
    // Must be const: KdTree::privateLookup() invokes it through a `const LookupProc &`, so a
    // non-const process() would not compile once lookup() is instantiated with this type.
    void process( uint32_t /*id*/, float /*distSqrd*/, float & /*maxDistSqrd*/ ) const {}
};
00061 
00062 template <typename NodeData, unsigned char K=3, class LookupProc = NullLookupProc> class KdTree {
00063 public:
00064     typedef std::pair<const NodeData*, uint32_t> NodeDataIndex;
00065     
00066     // KdTree Public Methods
00067     template<typename NodeDataVector>
00068     KdTree( const NodeDataVector &data );
00069     KdTree() {}
00070     template<typename NodeDataVector>
00071     void initialize( const NodeDataVector &d );
00072     ~KdTree() {
00073         free( nodes );
00074         delete[] mNodeData;
00075     }
00076     void recursiveBuild( uint32_t nodeNum, uint32_t start, uint32_t end, std::vector<NodeDataIndex> &buildNodes );
00077     void lookup( const NodeData &p, const LookupProc &process, float maxDist ) const;
00078     void findNearest( float p[K], float result[K], uint32_t *resultIndex ) const;
00079     
00080 private:
00081     // KdTree Private Methods
00082     void privateLookup(uint32_t nodeNum, float p[K], const LookupProc &process, float &maxDistSquared) const;
00083     void privateFindNearest( uint32_t nodeNum, float p[K], float &maxDistSquared, float result[K], uint32_t *resultIndex ) const;
00084     // KdTree Private Data
00085     KdNode<K> *nodes;
00086     NodeDataIndex *mNodeData;
00087     uint32_t nNodes, nextFreeNode;
00088 };
00089 
00090 
00091 // Shims
template<typename NDV>
struct NodeDataVectorTraits {
    // Element count of the input container, narrowed to 32 bits. NDV only needs a size() member.
    static uint32_t getSize( const NDV &ndv )
    {
        const uint32_t count = static_cast<uint32_t>( ndv.size() );
        return count;
    }
};
00099 
// Default coordinate accessors for a 3D point type exposing members x, y and z.
template<typename NodeData>
struct NodeDataTraits
{
    // Coordinate along `axis` as a float: 0 selects x, 1 selects y, any other value selects z.
    // Delegates to getAxisN so every axis is converted the same way (explicit static_cast).
    static float getAxis( const NodeData &data, int axis ) {
        if( axis == 0 ) return getAxis0( data );
        else if( axis == 1 ) return getAxis1( data );
        else return getAxis2( data );
    }
    static float getAxis0( const NodeData &data ) { return static_cast<float>( data.x ); }
    static float getAxis1( const NodeData &data ) { return static_cast<float>( data.y ); }
    static float getAxis2( const NodeData &data ) { return static_cast<float>( data.z ); }
    // Squared Euclidean distance between `data` and the 3-element query point `k`.
    static float distanceSquared( const NodeData &data, float k[3] ) {
        float result = ( data.x - k[0] ) * ( data.x - k[0] );
        result += ( data.y - k[1] ) * ( data.y - k[1] );
        result += ( data.z - k[2] ) * ( data.z - k[2] );
        return result;
    }
};
00118 
00119 template<>
00120 struct NodeDataTraits<Vec2f>
00121 {
00122     static float getAxis( const Vec2f &data, int axis ) {
00123         if( axis == 0 ) return data.x;
00124         else return data.y; 
00125     }
00126     static float getAxis0( const Vec2f &data ) { return static_cast<float>( data.x ); }
00127     static float getAxis1( const Vec2f &data ) { return static_cast<float>( data.y ); }
00128     static float distanceSquared( const Vec2f &data, float k[2] ) {
00129         float result = ( data.x - k[0] ) * ( data.x - k[0] );
00130         result += ( data.y - k[1] ) * ( data.y - k[1] );
00131         return result;
00132     }
00133 };
00134 
00135 template<typename NodeData> struct CompareNode {
00136     CompareNode( int a ) { axis = a; }
00137     int axis;
00138     bool operator()(const std::pair<const NodeData*,uint32_t> &d1,
00139             const std::pair<const NodeData*,uint32_t> &d2) const {
00140         return NodeDataTraits<NodeData>::getAxis( *d1.first, axis ) == NodeDataTraits<NodeData>::getAxis( *d2.first, axis ) ? ( d1.first < d2.first ) :
00141             NodeDataTraits<NodeData>::getAxis( *d1.first, axis ) < NodeDataTraits<NodeData>::getAxis( *d2.first, axis );
00142     }
00143 };
00144 
00145 // KdTree Method Definitions
00146 template<typename NodeData, unsigned char K, typename LookupProc>
00147  template<typename NodeDataVector>
00148 KdTree<NodeData, K, LookupProc>::KdTree(const NodeDataVector &d)
00149 {
00150     initialize( d );
00151 }
00152 
00153 template<typename NodeData, unsigned char K, typename LookupProc>
00154  template<typename NodeDataVector>
00155 void KdTree<NodeData, K, LookupProc>::initialize( const NodeDataVector &d )
00156 {
00157     nNodes = NodeDataVectorTraits<NodeDataVector>::getSize( d );
00158     nextFreeNode = 1;
00159     nodes = (KdNode<K> *)malloc(nNodes * sizeof(KdNode<K>));
00160     mNodeData = new NodeDataIndex[nNodes];
00161     std::vector<NodeDataIndex> buildNodes;
00162     buildNodes.reserve( nNodes );
00163     for( uint32_t i = 0; i < nNodes; ++i )
00164         buildNodes.push_back( std::make_pair( &d[i], i ) );
00165     // Begin the KdTree building process
00166     recursiveBuild( 0, 0, nNodes, buildNodes );
00167 }
00168 
00169 template<typename NodeData, unsigned char K, typename LookupProc>
00170 void KdTree<NodeData, K, LookupProc>::recursiveBuild( uint32_t nodeNum, uint32_t start, uint32_t end, std::vector<NodeDataIndex> &buildNodes )
00171 {
00172     // Create leaf node of kd-tree if we've reached the bottom
00173     if( start + 1 == end) {
00174         nodes[nodeNum].initLeaf();
00175         mNodeData[nodeNum] = buildNodes[start];
00176         return;
00177     }
00178     // Choose split direction and partition data
00179     // Compute bounds of data from _start_ to _end_
00180     float boundMin[K], boundMax[K];
00181     for( unsigned char k = 0; k < K; ++k ) {
00182         boundMin[k] = FLT_MAX;
00183         boundMax[k] = FLT_MIN;
00184     }
00185     
00186     for( uint32_t i = start; i < end; ++i ) {
00187         for( uint8_t axis = 0; axis < K; axis++ ) {
00188             // NOT Compiling? you should define NOMINMAX
00189             boundMin[axis] = std::min( boundMin[axis], NodeDataTraits<NodeData>::getAxis( *buildNodes[i].first, axis ) );
00190             boundMax[axis] = std::max( boundMax[axis], NodeDataTraits<NodeData>::getAxis( *buildNodes[i].first, axis ) );
00191         }
00192     }
00193     int splitAxis = 0;
00194     float maxExtent = boundMax[0] - boundMin[0];
00195     for( unsigned char k = 1; k < K; ++k ) {
00196         if( boundMax[k] - boundMin[k] > maxExtent ) {
00197             splitAxis = k;
00198             maxExtent = boundMax[k] - boundMin[k];
00199         }   
00200     }
00201     uint32_t splitPos = ( start + end ) / 2;
00202     std::nth_element( &buildNodes[start], &buildNodes[splitPos], &buildNodes[end-1] + 1, CompareNode<NodeData>(splitAxis) );
00203     // Allocate kd-tree node and continue recursively
00204     nodes[nodeNum].init( NodeDataTraits<NodeData>::getAxis( *buildNodes[splitPos].first, splitAxis ), splitAxis );
00205     mNodeData[nodeNum] = buildNodes[splitPos];
00206     if( start < splitPos ) {
00207         nodes[nodeNum].hasLeftChild = 1;
00208         uint32_t childNum = nextFreeNode++;
00209         recursiveBuild( childNum, start, splitPos, buildNodes );
00210     }
00211     if( splitPos + 1 < end ) {
00212         nodes[nodeNum].rightChild = nextFreeNode++;
00213         recursiveBuild( nodes[nodeNum].rightChild, splitPos + 1, end, buildNodes );
00214     }
00215 }
00216 
00217 template<typename NodeData, unsigned char K, typename LookupProc>
00218 void KdTree<NodeData, K, LookupProc>::lookup( const NodeData &p, const LookupProc &proc, float maxDist ) const 
00219 {
00220     float maxDistSqrd = maxDist * maxDist;
00221     float pt[K];
00222     for( unsigned char k = 0; k < K; ++k )
00223         pt[k] = NodeDataTraits<NodeData>::getAxis( p, k );
00224 
00225     privateLookup( 0, pt, proc, maxDistSqrd );
00226 }
00227 
00228 template<typename NodeData, unsigned char K, typename LookupProc>
00229 void KdTree<NodeData, K, LookupProc>::privateLookup( uint32_t nodeNum, float p[K], const LookupProc &process, float &maxDistSquared ) const 
00230 {
00231     KdNode<K> *node = &nodes[nodeNum];
00232     // process kd-tree node's children
00233     int axis = node->splitAxis;
00234     if( axis != K ) {
00235         float dist2 = ( p[axis] - node->splitPos ) * ( p[axis] - node->splitPos );
00236         if( p[axis] <= node->splitPos ) {
00237             if(node->hasLeftChild)
00238                 privateLookup( nodeNum + 1, p, process, maxDistSquared );
00239             if( ( dist2 < maxDistSquared ) && ( node->rightChild < nNodes ) )
00240                 privateLookup( node->rightChild, p, process, maxDistSquared );
00241         }
00242         else {
00243             if( node->rightChild < nNodes )
00244                 privateLookup( node->rightChild, p, process, maxDistSquared );
00245             if( ( dist2 < maxDistSquared ) && node->hasLeftChild )
00246                 privateLookup( nodeNum + 1, p, process, maxDistSquared );
00247         }
00248     }
00249     // Hand kd-tree node to processing function
00250     float distSqr = 0.0f;
00251     for( unsigned char k = 0; k < K; ++k ) {
00252         float v = NodeDataTraits<NodeData>::getAxis( *mNodeData[nodeNum].first, k ) - p[k];
00253         distSqr += v * v;
00254     }
00255     
00256     if( distSqr < maxDistSquared )
00257         process.process( mNodeData[nodeNum].second, distSqr, maxDistSquared );
00258 }
00259 
00260 // Find Nearest
00261 template<typename NodeData, unsigned char K, typename LookupProc>
00262 void KdTree<NodeData, K, LookupProc>::findNearest( float p[K], float result[K], uint32_t *resultIndex ) const
00263 {
00264     float maxDist = FLT_MAX;
00265     *resultIndex = -1;
00266     privateFindNearest( 0, p, maxDist, result, resultIndex );
00267 }
00268 
00269 template<typename NodeData, unsigned char K, typename LookupProc>
00270 void KdTree<NodeData, K, LookupProc>::privateFindNearest( uint32_t nodeNum, float p[K], float &maxDistSquared, float result[K], uint32_t *resultIndex ) const
00271 {
00272     KdNode<K> *node = &nodes[nodeNum];
00273     // process kd-tree node's children
00274     int axis = node->splitAxis;
00275     if( axis != K ) {
00276         float dist2 = (p[axis] - node->splitPos) * (p[axis] - node->splitPos);
00277         if( p[axis] <= node->splitPos ) {
00278             if( node->hasLeftChild )
00279                 privateFindNearest( nodeNum+1, p, maxDistSquared, result, resultIndex );
00280             if( ( dist2 < maxDistSquared ) && ( node->rightChild < nNodes) )
00281                 privateFindNearest( node->rightChild, p, maxDistSquared, result, resultIndex );
00282         }
00283         else {
00284             if( node->rightChild < nNodes)
00285                 privateFindNearest(node->rightChild,
00286                               p,
00287                               maxDistSquared, result, resultIndex );
00288             if( dist2 < maxDistSquared && node->hasLeftChild)
00289                 privateFindNearest(nodeNum+1,
00290                               p,
00291                               maxDistSquared, result, resultIndex );
00292         }
00293     }
00294     
00295     float distSqr = NodeDataTraits<NodeData>::distanceSquared( *mNodeData[nodeNum].first, p );
00296     if( distSqr < maxDistSquared ) {
00297         maxDistSquared = distSqr;
00298         for( unsigned char k = 0; k < K; ++k )
00299             result[k] = NodeDataTraits<NodeData>::getAxis( *mNodeData[nodeNum].first, k );
00300         *resultIndex = mNodeData[nodeNum].second;
00301     }
00302 }
00303 
00304 } // namespace ci