2 * This file is part of OpenTTD.
3 * OpenTTD is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 2.
4 * OpenTTD is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
5 * See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with OpenTTD. If not, see <http://www.gnu.org/licenses/>.
8 /** @file kdtree.hpp K-d tree template specialised for 2-dimensional Manhattan geometry */
13 #include "../stdafx.h"
16 * K-dimensional tree, specialised for 2-dimensional space.
17 * This is not intended as a primary storage of data, but as an index into existing data.
18 * Usually the type stored by this tree should be an index into an existing array.
20 * This implementation assumes Manhattan distances are used.
22 * Be careful when using this in game code: depending on the usage pattern, the tree shape may
23 * end up different for different clients in multiplayer, causing iteration order to differ
24 * and possibly returning elements in a different order. The calling code should be designed
25 * to produce the same result regardless of iteration order.
27 * The element type T must be less-than comparable for FindNearest to work.
29 * @tparam T Type stored in the tree, should be cheap to copy.
30 * @tparam TxyFunc Functor type to extract coordinate from a T value and dimension index (0 or 1).
31 * @tparam CoordT Type of coordinate values extracted via TxyFunc.
32 * @tparam DistT Type to use for representing distance values.
/* Parameters of the Kdtree class template; each one is described by the @tparam list in the comment above. */
34 template <typename T
, typename TxyFunc
, typename CoordT
, typename DistT
>
36 /** Type of a node in the tree */
38 T element
; ///< Element stored at node
39 size_t left
; ///< Index of node to the left, INVALID_NODE if none
40 size_t right
; ///< Index of node to the right, INVALID_NODE if none
42 node(T element
) : element(element
), left(INVALID_NODE
), right(INVALID_NODE
) { }
45 static const size_t INVALID_NODE
= SIZE_MAX
; ///< Index value indicating no-such-node
47 std::vector
<node
> nodes
; ///< Pool of all nodes in the tree
48 std::vector
<size_t> free_list
; ///< List of dead indices in the nodes vector
49 size_t root
; ///< Index of root node
50 TxyFunc xyfunc
; ///< Functor to extract a coordinate from an element
51 size_t unbalanced
; ///< Number approximating how unbalanced the tree might be
53 /** Create one new node in the tree, return its index in the pool */
54 size_t AddNode(const T
&element
)
56 if (this->free_list
.empty()) {
57 this->nodes
.emplace_back(element
);
58 return this->nodes
.size() - 1;
60 size_t newidx
= this->free_list
.back();
61 this->free_list
.pop_back();
62 this->nodes
[newidx
] = node
{ element
};
67 /** Find a coordinate value to split a range of elements at */
68 template <typename It
>
69 CoordT
SelectSplitCoord(It begin
, It end
, int level
)
71 It mid
= begin
+ (end
- begin
) / 2;
72 std::nth_element(begin
, mid
, end
, [&](T a
, T b
) { return this->xyfunc(a
, level
% 2) < this->xyfunc(b
, level
% 2); });
73 return this->xyfunc(*mid
, level
% 2);
76 /** Construct a subtree from elements between begin and end iterators, return index of root */
77 template <typename It
>
78 size_t BuildSubtree(It begin
, It end
, int level
)
80 ptrdiff_t count
= end
- begin
;
84 } else if (count
== 1) {
85 return this->AddNode(*begin
);
86 } else if (count
> 1) {
87 CoordT split_coord
= SelectSplitCoord(begin
, end
, level
);
88 It split
= std::partition(begin
, end
, [&](T v
) { return this->xyfunc(v
, level
% 2) < split_coord
; });
89 size_t newidx
= this->AddNode(*split
);
90 this->nodes
[newidx
].left
= this->BuildSubtree(begin
, split
, level
+ 1);
91 this->nodes
[newidx
].right
= this->BuildSubtree(split
+ 1, end
, level
+ 1);
98 /** Rebuild the tree with all existing elements, optionally adding or removing one more */
/* @param include_element Element to add to the rebuilt tree, or nullptr to add nothing.
 * @param exclude_element Element to drop from the rebuilt tree, or nullptr to drop nothing.
 * @return false without modifying the tree when it holds fewer than 8 elements. The success
 *         return is not on the lines shown here; Insert and Remove treat a false result as
 *         "fall back to a single-node update". */
99 bool Rebuild(const T
*include_element
, const T
*exclude_element
)
101 size_t initial_count
= this->Count();
102 if (initial_count
< 8) return false; // arbitrary value for "not worth rebalancing"
/* FreeSubtree only collects the root's descendants, so the root element is saved separately. */
104 T root_element
= this->nodes
[this->root
].element
;
105 std::vector
<T
> elements
= this->FreeSubtree(this->root
);
106 elements
.push_back(root_element
);
108 if (include_element
!= nullptr) {
109 elements
.push_back(*include_element
);
/* Drop every copy of the excluded element; std::remove leaves the list unchanged if it is absent. */
112 if (exclude_element
!= nullptr) {
113 typename
std::vector
<T
>::iterator removed
= std::remove(elements
.begin(), elements
.end(), *exclude_element
);
114 elements
.erase(removed
, elements
.end());
/* Build() resets the free list and imbalance estimate and builds a fresh tree from the collected elements. */
118 this->Build(elements
.begin(), elements
.end());
/* NOTE(review): none of the lines shown adjust initial_count for include_element/exclude_element,
 * so unless a line not shown here does, this assert fails whenever an element is added or removed - confirm. */
119 assert(initial_count
== this->Count());
123 /** Insert one element in the tree somewhere below node_idx */
124 void InsertRecursive(const T
&element
, size_t node_idx
, int level
)
126 /* Dimension index of current level */
129 node
&n
= this->nodes
[node_idx
];
131 /* Coordinate of element splitting at this node */
132 CoordT nc
= this->xyfunc(n
.element
, dim
);
133 /* Coordinate of the new element */
134 CoordT ec
= this->xyfunc(element
, dim
);
135 /* Which side to insert on */
136 size_t &next
= (ec
< nc
) ? n
.left
: n
.right
;
138 if (next
== INVALID_NODE
) {
140 size_t newidx
= this->AddNode(element
);
141 /* Vector may have been reallocated at this point, n and next are invalid */
142 node
&nn
= this->nodes
[node_idx
];
143 if (ec
< nc
) nn
.left
= newidx
; else nn
.right
= newidx
;
145 this->InsertRecursive(element
, next
, level
+ 1);
150 * Free all children of the given node
151 * @return Collection of elements that were removed from tree.
153 std::vector
<T
> FreeSubtree(size_t node_idx
)
155 std::vector
<T
> subtree_elements
;
156 node
&n
= this->nodes
[node_idx
];
158 /* We'll be appending items to the free_list, get index of our first item */
159 size_t first_free
= this->free_list
.size();
160 /* Prepare the descent with our children */
161 if (n
.left
!= INVALID_NODE
) this->free_list
.push_back(n
.left
);
162 if (n
.right
!= INVALID_NODE
) this->free_list
.push_back(n
.right
);
163 n
.left
= n
.right
= INVALID_NODE
;
165 /* Recursively free the nodes being collected */
166 for (size_t i
= first_free
; i
< this->free_list
.size(); i
++) {
167 node
&fn
= this->nodes
[this->free_list
[i
]];
168 subtree_elements
.push_back(fn
.element
);
169 if (fn
.left
!= INVALID_NODE
) this->free_list
.push_back(fn
.left
);
170 if (fn
.right
!= INVALID_NODE
) this->free_list
.push_back(fn
.right
);
171 fn
.left
= fn
.right
= INVALID_NODE
;
174 return subtree_elements
;
178 * Find and remove one element from the tree.
179 * @param element The element to search for
180 * @param node_idx Sub-tree to search in
181 * @param level Current depth in the tree
182 * @return New root node index of the sub-tree processed
184 size_t RemoveRecursive(const T
&element
, size_t node_idx
, int level
)
187 node
&n
= this->nodes
[node_idx
];
189 if (n
.element
== element
) {
190 /* Remove this one */
191 this->free_list
.push_back(node_idx
);
192 if (n
.left
== INVALID_NODE
&& n
.right
== INVALID_NODE
) {
193 /* Simple case, leaf, new child node for parent is "none" */
196 /* Complex case, rebuild the sub-tree */
197 std::vector
<T
> subtree_elements
= this->FreeSubtree(node_idx
);
198 return this->BuildSubtree(subtree_elements
.begin(), subtree_elements
.end(), level
);;
201 /* Search in a sub-tree */
202 /* Dimension index of current level */
204 /* Coordinate of element splitting at this node */
205 CoordT nc
= this->xyfunc(n
.element
, dim
);
206 /* Coordinate of the element being removed */
207 CoordT ec
= this->xyfunc(element
, dim
);
208 /* Which side to remove from */
209 size_t next
= (ec
< nc
) ? n
.left
: n
.right
;
210 assert(next
!= INVALID_NODE
); // node must exist somewhere and must be found before a leaf is reached
212 size_t new_branch
= this->RemoveRecursive(element
, next
, level
+ 1);
213 if (new_branch
!= next
) {
214 /* Vector may have been reallocated at this point, n and next are invalid */
215 node
&nn
= this->nodes
[node_idx
];
216 if (ec
< nc
) nn
.left
= new_branch
; else nn
.right
= new_branch
;
223 DistT
ManhattanDistance(const T
&element
, CoordT x
, CoordT y
) const
225 return abs((DistT
)this->xyfunc(element
, 0) - (DistT
)x
) + abs((DistT
)this->xyfunc(element
, 1) - (DistT
)y
);
228 /** A data element and its distance to a searched-for point */
229 using node_distance
= std::pair
<T
, DistT
>;
230 /** Ordering function for node_distance objects, elements with equal distance are ordered by less-than comparison */
231 static node_distance
SelectNearestNodeDistance(const node_distance
&a
, const node_distance
&b
)
233 if (a
.second
< b
.second
) return a
;
234 if (b
.second
< a
.second
) return b
;
235 if (a
.first
< b
.first
) return a
;
236 if (b
.first
< a
.first
) return b
;
237 NOT_REACHED(); // a.first == b.first: same element must not be inserted twice
239 /** Search a sub-tree for the element nearest to a given point */
240 node_distance
FindNearestRecursive(CoordT xy
[2], size_t node_idx
, int level
, DistT limit
= std::numeric_limits
<DistT
>::max()) const
242 /* Dimension index of current level */
245 const node
&n
= this->nodes
[node_idx
];
247 /* Coordinate of element splitting at this node */
248 CoordT c
= this->xyfunc(n
.element
, dim
);
249 /* This node's distance to target */
250 DistT thisdist
= ManhattanDistance(n
.element
, xy
[0], xy
[1]);
251 /* Assume this node is the best choice for now */
252 node_distance best
= std::make_pair(n
.element
, thisdist
);
254 /* Next node to visit */
255 size_t next
= (xy
[dim
] < c
) ? n
.left
: n
.right
;
256 if (next
!= INVALID_NODE
) {
257 /* Check if there is a better node down the tree */
258 best
= SelectNearestNodeDistance(best
, this->FindNearestRecursive(xy
, next
, level
+ 1));
261 limit
= std::min(best
.second
, limit
);
263 /* Check if the distance from current best is worse than distance from target to splitting line,
264 * if it is we also need to check the other side of the split. */
265 size_t opposite
= (xy
[dim
] >= c
) ? n
.left
: n
.right
; // reverse of above
266 if (opposite
!= INVALID_NODE
&& limit
>= abs((int)xy
[dim
] - (int)c
)) {
267 node_distance other_candidate
= this->FindNearestRecursive(xy
, opposite
, level
+ 1, limit
);
268 best
= SelectNearestNodeDistance(best
, other_candidate
);
274 template <typename Outputter
>
275 void FindContainedRecursive(CoordT p1
[2], CoordT p2
[2], size_t node_idx
, int level
, const Outputter
&outputter
) const
277 /* Dimension index of current level */
280 const node
&n
= this->nodes
[node_idx
];
282 /* Coordinate of element splitting at this node */
283 CoordT ec
= this->xyfunc(n
.element
, dim
);
284 /* Opposite coordinate of element */
285 CoordT oc
= this->xyfunc(n
.element
, 1 - dim
);
287 /* Test if this element is within rectangle */
288 if (ec
>= p1
[dim
] && ec
< p2
[dim
] && oc
>= p1
[1 - dim
] && oc
< p2
[1 - dim
]) outputter(n
.element
);
290 /* Recurse left if part of rectangle is left of split */
291 if (p1
[dim
] < ec
&& n
.left
!= INVALID_NODE
) this->FindContainedRecursive(p1
, p2
, n
.left
, level
+ 1, outputter
);
293 /* Recurse right if part of rectangle is right of split */
294 if (p2
[dim
] > ec
&& n
.right
!= INVALID_NODE
) this->FindContainedRecursive(p1
, p2
, n
.right
, level
+ 1, outputter
);
297 /** Debugging function, counts number of occurrences of an element regardless of its correct position in the tree */
298 size_t CountValue(const T
&element
, size_t node_idx
) const
300 if (node_idx
== INVALID_NODE
) return 0;
301 const node
&n
= this->nodes
[node_idx
];
302 return CountValue(element
, n
.left
) + CountValue(element
, n
.right
) + ((n
.element
== element
) ? 1 : 0);
305 void IncrementUnbalanced(size_t amount
= 1)
307 this->unbalanced
+= amount
;
310 /** Check if the entire tree is in need of rebuilding */
313 size_t count
= this->Count();
314 if (count
< 8) return false;
315 return this->unbalanced
> this->Count() / 4;
318 /** Verify that the invariant is true for a sub-tree, assert if not */
/* Bounds passed down: for an x-split the left child gets (min_x, cx) and the right child (cx, max_x);
 * y-splits narrow (min_y, max_y) the same way with cy.
 * NOTE(review): the assertions comparing cx/cy against these bounds are not on the lines shown here - confirm they exist. */
319 void CheckInvariant(size_t node_idx
, int level
, CoordT min_x
, CoordT max_x
, CoordT min_y
, CoordT max_y
)
321 if (node_idx
== INVALID_NODE
) return;
323 const node
&n
= this->nodes
[node_idx
];
324 CoordT cx
= this->xyfunc(n
.element
, 0);
325 CoordT cy
= this->xyfunc(n
.element
, 1);
/* Even levels split on x and odd levels on y, matching the level % 2 axis choice used when building. */
332 if (level
% 2 == 0) {
333 // split in dimension 0 = x
334 CheckInvariant(n
.left
, level
+ 1, min_x
, cx
, min_y
, max_y
);
335 CheckInvariant(n
.right
, level
+ 1, cx
, max_x
, min_y
, max_y
);
337 // split in dimension 1 = y
338 CheckInvariant(n
.left
, level
+ 1, min_x
, max_x
, min_y
, cy
);
339 CheckInvariant(n
.right
, level
+ 1, min_x
, max_x
, cy
, max_y
);
343 /** Verify the invariant for the entire tree, does nothing unless KDTREE_DEBUG is defined */
344 void CheckInvariant()
/* Starts at the root with the full CoordT range as bounds on both axes.
 * NOTE(review): the KDTREE_DEBUG guard mentioned above is not on the lines shown here; presumably it wraps this call - confirm.
 * numeric_limits<CoordT>::min() is the lowest value only for integer CoordT; for floating point it is the smallest positive value. */
347 CheckInvariant(this->root
, 0, std::numeric_limits
<CoordT
>::min(), std::numeric_limits
<CoordT
>::max(), std::numeric_limits
<CoordT
>::min(), std::numeric_limits
<CoordT
>::max());
352 /** Construct a new Kdtree with the given xyfunc */
353 Kdtree(TxyFunc xyfunc
) : root(INVALID_NODE
), xyfunc(xyfunc
), unbalanced(0) { }
356 * Clear and rebuild the tree from a new sequence of elements.
357 * @tparam It Iterator type for element sequence.
358 * @param begin First element in sequence.
359 * @param end One past last element in sequence.
361 template <typename It
>
362 void Build(It begin
, It end
)
/* It must support random access: end - begin is used here and by BuildSubtree.
 * NOTE(review): clearing this->nodes is not on the lines shown here; unless a hidden line does it,
 * old nodes would survive and be counted by Count() - confirm. */
365 this->free_list
.clear();
366 this->unbalanced
= 0;
/* An empty sequence leaves the tree empty; this->root is not reassigned on this path. */
367 if (begin
== end
) return;
/* BuildSubtree adds exactly one node per element, so reserve the whole pool up front. */
368 this->nodes
.reserve(end
- begin
);
/* Level 0 splits on the first coordinate. */
370 this->root
= this->BuildSubtree(begin
, end
, 0);
/* NOTE(review): fragment of a method whose signature is not on the lines shown here.
 * It drops the recycled-slot list and resets the imbalance estimate. */
380 this->free_list
.clear();
381 this->unbalanced
= 0;
386 * Reconstruct the tree with the same elements, letting it be fully balanced.
390 this->Rebuild(nullptr, nullptr);
/* Neither adds nor removes an element. The bool result is ignored: Rebuild returns false
 * without touching the tree when it holds fewer than 8 elements. */
394 * Insert a single element in the tree.
395 * Repeatedly inserting single elements may cause the tree to become unbalanced.
396 * Undefined behaviour if the element already exists in the tree.
398 void Insert(const T
&element
)
/* An empty tree simply gets the element as its root node. */
400 if (this->Count() == 0) {
401 this->root
= this->AddNode(element
);
/* If the tree is flagged as unbalanced, try a full rebuild that includes the new element. If no
 * rebuild happens (not unbalanced, or too small to rebuild), attach the element as a new leaf
 * and raise the imbalance estimate. */
403 if (!this->IsUnbalanced() || !this->Rebuild(&element
, nullptr)) {
404 this->InsertRecursive(element
, this->root
, 0);
405 this->IncrementUnbalanced();
412 * Remove a single element from the tree. The element should be present: on the single-node path, RemoveRecursive asserts if its search runs off the tree without finding it.
413 * Since elements are stored in interior nodes as well as leaf nodes, removing one may
414 * require a larger sub-tree to be re-built. Because of this, worst case run time is
415 * as bad as a full tree rebuild.
417 void Remove(const T
&element
)
419 size_t count
= this->Count();
/* Nothing to remove from an empty tree. */
420 if (count
== 0) return;
/* If the tree is flagged as unbalanced, prefer a full rebuild without the element. Otherwise (or if the
 * tree is too small to rebuild) remove it in place and raise the imbalance estimate. */
421 if (!this->IsUnbalanced() || !this->Rebuild(nullptr, &element
)) {
422 /* If the removed element is the root node, this modifies this->root */
423 this->root
= this->RemoveRecursive(element
, this->root
, 0);
424 this->IncrementUnbalanced();
429 /** Get number of elements stored in tree */
432 assert(this->free_list
.size() <= this->nodes
.size());
433 return this->nodes
.size() - this->free_list
.size();
437 * Find the element closest to given coordinate, in Manhattan distance.
438 * For multiple elements with the same distance, the one comparing smaller with
439 * a less-than comparison is chosen.
441 T
FindNearest(CoordT x
, CoordT y
) const
443 assert(this->Count() > 0);
445 CoordT xy
[2] = { x
, y
};
446 return this->FindNearestRecursive(xy
, this->root
, 0).first
;
450 * Find all items contained within the given rectangle.
451 * @note Start coordinates are inclusive, end coordinates are exclusive. x1<x2 && y1<y2 is a precondition.
452 * @param x1 Start first coordinate, points found are greater than or equal to this.
453 * @param y1 Start second coordinate, points found are greater than or equal to this.
454 * @param x2 End first coordinate, points found are less than this.
455 * @param y2 End second coordinate, points found are less than this.
456 * @param outputter Callback invoked with each element found inside the rectangle.
458 template <typename Outputter
>
459 void FindContained(CoordT x1
, CoordT y1
, CoordT x2
, CoordT y2
, const Outputter
&outputter
) const
/* An empty tree has no valid root to descend from. */
464 if (this->Count() == 0) return;
/* Pack the corners into arrays indexed by dimension, as FindContainedRecursive expects. */
466 CoordT p1
[2] = { x1
, y1
};
467 CoordT p2
[2] = { x2
, y2
};
468 this->FindContainedRecursive(p1
, p2
, this->root
, 0, outputter
);
472 * Find all items contained within the given rectangle.
473 * @note Start coordinates are inclusive, end coordinates are exclusive. x1<x2 && y1<y2 is a precondition.
475 std::vector
<T
> FindContained(CoordT x1
, CoordT y1
, CoordT x2
, CoordT y2
) const
/* Collect the matches through the callback-based overload.
 * NOTE(review): the rest of this function lies beyond the lines shown; presumably it returns result. */
477 std::vector
<T
> result
;
478 this->FindContained(x1
, y1
, x2
, y2
, [&result
](T e
) {result
.push_back(e
); });