00001 /* 00002 Copyright (c) 2009, The Barbarian Group 00003 All rights reserved. 00004 00005 Redistribution and use in source and binary forms, with or without modification, are permitted provided that 00006 the following conditions are met: 00007 00008 * Redistributions of source code must retain the above copyright notice, this list of conditions and 00009 the following disclaimer. 00010 * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and 00011 the following disclaimer in the documentation and/or other materials provided with the distribution. 00012 00013 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED 00014 WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 00015 PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 00016 ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 00017 TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 00018 HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 00019 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 00020 POSSIBILITY OF SUCH DAMAGE. 00021 */ 00022 00023 #pragma once 00024 00025 #include "cinder/Cinder.h" 00026 #include "cinder/Vector.h" 00027 00028 #include <vector> 00029 #include <float.h> 00030 #include <stdlib.h> 00031 #include <algorithm> 00032 #include <utility> 00033 00034 namespace cinder { 00035 00036 // KdTree Declarations 00037 template<unsigned char K> 00038 struct KdNode { 00039 void init( float p, uint32_t a) { 00040 splitPos = p; 00041 splitAxis = a; 00042 rightChild = ~0; 00043 hasLeftChild = 0; 00044 } 00045 void initLeaf() { 00046 splitAxis = K; 00047 rightChild = ~0; 00048 hasLeftChild = 0; 00049 } 00050 // KdNode Data 00051 float splitPos; 00052 uint32_t splitAxis:2; 00053 uint32_t hasLeftChild:1; 00054 uint32_t rightChild:29; 00055 }; 00056 00057 struct NullLookupProc { 00058 public: 00059 void process( uint32_t id, float distSqrd, float &maxDistSqrd ) {} 00060 }; 00061 00062 template <typename NodeData, unsigned char K=3, class LookupProc = NullLookupProc> class KdTree { 00063 public: 00064 typedef std::pair<const NodeData*, uint32_t> NodeDataIndex; 00065 00066 // KdTree Public Methods 00067 template<typename NodeDataVector> 00068 KdTree( const NodeDataVector &data ); 00069 KdTree() {} 00070 template<typename NodeDataVector> 00071 void initialize( const NodeDataVector &d ); 00072 ~KdTree() { 00073 free( nodes ); 00074 delete[] mNodeData; 00075 } 00076 void recursiveBuild( uint32_t nodeNum, uint32_t start, uint32_t end, std::vector<NodeDataIndex> &buildNodes ); 00077 void lookup( const NodeData &p, const LookupProc &process, float maxDist ) const; 00078 void findNearest( float p[K], float result[K], uint32_t *resultIndex ) const; 00079 00080 private: 00081 // KdTree Private Methods 00082 void privateLookup(uint32_t nodeNum, float p[K], const LookupProc &process, float &maxDistSquared) const; 00083 void privateFindNearest( uint32_t nodeNum, float p[K], float &maxDistSquared, float result[K], uint32_t *resultIndex ) const; 00084 // KdTree Private Data 00085 KdNode<K> *nodes; 00086 NodeDataIndex *mNodeData; 00087 uint32_t nNodes, nextFreeNode; 00088 }; 00089 00090 00091 // Shims 00092 template<typename NDV> 00093 struct NodeDataVectorTraits 00094 { 00095 static uint32_t getSize( const NDV &ndv ) { 00096 return static_cast<uint32_t>( ndv.size() ); 00097 } 00098 }; 00099 00100 template<typename NodeData> 00101 struct NodeDataTraits 00102 { 00103 static float getAxis( const NodeData &data, int axis ) { 00104 if( axis == 0 ) return data.x; 00105 else if( axis == 1 ) return data.y; 00106 else return (float)data.z; 00107 } 00108 static float getAxis0( const NodeData &data ) { return static_cast<float>( data.x ); } 00109 static float getAxis1( const NodeData &data ) { return static_cast<float>( data.y ); } 00110 static float getAxis2( const NodeData &data ) { return static_cast<float>( data.z ); } 00111 static float distanceSquared( const NodeData &data, float k[3] ) { 00112 float result = ( data.x - k[0] ) * ( data.x - k[0] ); 00113 result += ( data.y - k[1] ) * ( data.y - k[1] ); 00114 result += ( data.z - k[2] ) * ( data.z - k[2] ); 00115 return result; 00116 } 00117 }; 00118 00119 template<> 00120 struct NodeDataTraits<Vec2f> 00121 { 00122 static float getAxis( const Vec2f &data, int axis ) { 00123 if( axis == 0 ) return data.x; 00124 else return data.y; 00125 } 00126 static float getAxis0( const Vec2f &data ) { return static_cast<float>( data.x ); } 00127 static float getAxis1( const Vec2f &data ) { return static_cast<float>( data.y ); } 00128 static float distanceSquared( const Vec2f &data, float k[2] ) { 00129 float result = ( data.x - k[0] ) * ( data.x - k[0] ); 00130 result += ( data.y - k[1] ) * ( data.y - k[1] ); 00131 return result; 00132 } 00133 }; 00134 00135 template<typename NodeData> struct CompareNode { 00136 CompareNode( int a ) { axis = a; } 00137 int axis; 00138 bool operator()(const std::pair<const NodeData*,uint32_t> &d1, 00139 const std::pair<const NodeData*,uint32_t> &d2) const { 00140 return NodeDataTraits<NodeData>::getAxis( *d1.first, axis ) == NodeDataTraits<NodeData>::getAxis( *d2.first, axis ) ? ( d1.first < d2.first ) : 00141 NodeDataTraits<NodeData>::getAxis( *d1.first, axis ) < NodeDataTraits<NodeData>::getAxis( *d2.first, axis ); 00142 } 00143 }; 00144 00145 // KdTree Method Definitions 00146 template<typename NodeData, unsigned char K, typename LookupProc> 00147 template<typename NodeDataVector> 00148 KdTree<NodeData, K, LookupProc>::KdTree(const NodeDataVector &d) 00149 { 00150 initialize( d ); 00151 } 00152 00153 template<typename NodeData, unsigned char K, typename LookupProc> 00154 template<typename NodeDataVector> 00155 void KdTree<NodeData, K, LookupProc>::initialize( const NodeDataVector &d ) 00156 { 00157 nNodes = NodeDataVectorTraits<NodeDataVector>::getSize( d ); 00158 nextFreeNode = 1; 00159 nodes = (KdNode<K> *)malloc(nNodes * sizeof(KdNode<K>)); 00160 mNodeData = new NodeDataIndex[nNodes]; 00161 std::vector<NodeDataIndex> buildNodes; 00162 buildNodes.reserve( nNodes ); 00163 for( uint32_t i = 0; i < nNodes; ++i ) 00164 buildNodes.push_back( std::make_pair( &d[i], i ) ); 00165 // Begin the KdTree building process 00166 recursiveBuild( 0, 0, nNodes, buildNodes ); 00167 } 00168 00169 template<typename NodeData, unsigned char K, typename LookupProc> 00170 void KdTree<NodeData, K, LookupProc>::recursiveBuild( uint32_t nodeNum, uint32_t start, uint32_t end, std::vector<NodeDataIndex> &buildNodes ) 00171 { 00172 // Create leaf node of kd-tree if we've reached the bottom 00173 if( start + 1 == end) { 00174 nodes[nodeNum].initLeaf(); 00175 mNodeData[nodeNum] = buildNodes[start]; 00176 return; 00177 } 00178 // Choose split direction and partition data 00179 // Compute bounds of data from _start_ to _end_ 00180 float boundMin[K], boundMax[K]; 00181 for( unsigned char k = 0; k < K; ++k ) { 00182 boundMin[k] = FLT_MAX; 00183 boundMax[k] = FLT_MIN; 00184 } 00185 00186 for( uint32_t i = start; i < end; ++i ) { 00187 for( uint8_t axis = 0; axis < K; axis++ ) { 00188 // NOT Compiling? you should define NOMINMAX 00189 boundMin[axis] = std::min( boundMin[axis], NodeDataTraits<NodeData>::getAxis( *buildNodes[i].first, axis ) ); 00190 boundMax[axis] = std::max( boundMax[axis], NodeDataTraits<NodeData>::getAxis( *buildNodes[i].first, axis ) ); 00191 } 00192 } 00193 int splitAxis = 0; 00194 float maxExtent = boundMax[0] - boundMin[0]; 00195 for( unsigned char k = 1; k < K; ++k ) { 00196 if( boundMax[k] - boundMin[k] > maxExtent ) { 00197 splitAxis = k; 00198 maxExtent = boundMax[k] - boundMin[k]; 00199 } 00200 } 00201 uint32_t splitPos = ( start + end ) / 2; 00202 std::nth_element( &buildNodes[start], &buildNodes[splitPos], &buildNodes[end-1] + 1, CompareNode<NodeData>(splitAxis) ); 00203 // Allocate kd-tree node and continue recursively 00204 nodes[nodeNum].init( NodeDataTraits<NodeData>::getAxis( *buildNodes[splitPos].first, splitAxis ), splitAxis ); 00205 mNodeData[nodeNum] = buildNodes[splitPos]; 00206 if( start < splitPos ) { 00207 nodes[nodeNum].hasLeftChild = 1; 00208 uint32_t childNum = nextFreeNode++; 00209 recursiveBuild( childNum, start, splitPos, buildNodes ); 00210 } 00211 if( splitPos + 1 < end ) { 00212 nodes[nodeNum].rightChild = nextFreeNode++; 00213 recursiveBuild( nodes[nodeNum].rightChild, splitPos + 1, end, buildNodes ); 00214 } 00215 } 00216 00217 template<typename NodeData, unsigned char K, typename LookupProc> 00218 void KdTree<NodeData, K, LookupProc>::lookup( const NodeData &p, const LookupProc &proc, float maxDist ) const 00219 { 00220 float maxDistSqrd = maxDist * maxDist; 00221 float pt[K]; 00222 for( unsigned char k = 0; k < K; ++k ) 00223 pt[k] = NodeDataTraits<NodeData>::getAxis( p, k ); 00224 00225 privateLookup( 0, pt, proc, maxDistSqrd ); 00226 } 00227 00228 template<typename NodeData, unsigned char K, typename LookupProc> 00229 void KdTree<NodeData, K, LookupProc>::privateLookup( uint32_t nodeNum, float p[K], const LookupProc &process, float &maxDistSquared ) const 00230 { 00231 KdNode<K> *node = &nodes[nodeNum]; 00232 // process kd-tree node's children 00233 int axis = node->splitAxis; 00234 if( axis != K ) { 00235 float dist2 = ( p[axis] - node->splitPos ) * ( p[axis] - node->splitPos ); 00236 if( p[axis] <= node->splitPos ) { 00237 if(node->hasLeftChild) 00238 privateLookup( nodeNum + 1, p, process, maxDistSquared ); 00239 if( ( dist2 < maxDistSquared ) && ( node->rightChild < nNodes ) ) 00240 privateLookup( node->rightChild, p, process, maxDistSquared ); 00241 } 00242 else { 00243 if( node->rightChild < nNodes ) 00244 privateLookup( node->rightChild, p, process, maxDistSquared ); 00245 if( ( dist2 < maxDistSquared ) && node->hasLeftChild ) 00246 privateLookup( nodeNum + 1, p, process, maxDistSquared ); 00247 } 00248 } 00249 // Hand kd-tree node to processing function 00250 float distSqr = 0.0f; 00251 for( unsigned char k = 0; k < K; ++k ) { 00252 float v = NodeDataTraits<NodeData>::getAxis( *mNodeData[nodeNum].first, k ) - p[k]; 00253 distSqr += v * v; 00254 } 00255 00256 if( distSqr < maxDistSquared ) 00257 process.process( mNodeData[nodeNum].second, distSqr, maxDistSquared ); 00258 } 00259 00260 // Find Nearest 00261 template<typename NodeData, unsigned char K, typename LookupProc> 00262 void KdTree<NodeData, K, LookupProc>::findNearest( float p[K], float result[K], uint32_t *resultIndex ) const 00263 { 00264 float maxDist = FLT_MAX; 00265 *resultIndex = -1; 00266 privateFindNearest( 0, p, maxDist, result, resultIndex ); 00267 } 00268 00269 template<typename NodeData, unsigned char K, typename LookupProc> 00270 void KdTree<NodeData, K, LookupProc>::privateFindNearest( uint32_t nodeNum, float p[K], float &maxDistSquared, float result[K], uint32_t *resultIndex ) const 00271 { 00272 KdNode<K> *node = &nodes[nodeNum]; 00273 // process kd-tree node's children 00274 int axis = node->splitAxis; 00275 if( axis != K ) { 00276 float dist2 = (p[axis] - node->splitPos) * (p[axis] - node->splitPos); 00277 if( p[axis] <= node->splitPos ) { 00278 if( node->hasLeftChild ) 00279 privateFindNearest( nodeNum+1, p, maxDistSquared, result, resultIndex ); 00280 if( ( dist2 < maxDistSquared ) && ( node->rightChild < nNodes) ) 00281 privateFindNearest( node->rightChild, p, maxDistSquared, result, resultIndex ); 00282 } 00283 else { 00284 if( node->rightChild < nNodes) 00285 privateFindNearest(node->rightChild, 00286 p, 00287 maxDistSquared, result, resultIndex ); 00288 if( dist2 < maxDistSquared && node->hasLeftChild) 00289 privateFindNearest(nodeNum+1, 00290 p, 00291 maxDistSquared, result, resultIndex ); 00292 } 00293 } 00294 00295 float distSqr = NodeDataTraits<NodeData>::distanceSquared( *mNodeData[nodeNum].first, p ); 00296 if( distSqr < maxDistSquared ) { 00297 maxDistSquared = distSqr; 00298 for( unsigned char k = 0; k < K; ++k ) 00299 result[k] = NodeDataTraits<NodeData>::getAxis( *mNodeData[nodeNum].first, k ); 00300 *resultIndex = mNodeData[nodeNum].second; 00301 } 00302 } 00303 00304 } // namespace ci