2 * This file is part of OpenTTD.
3 * OpenTTD is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 2.
4 * OpenTTD is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
5 * See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with OpenTTD. If not, see <http://www.gnu.org/licenses/>.
8 /** @file kdtree.hpp K-d tree template specialised for 2-dimensional Manhattan geometry */
13 #include "../stdafx.h"
18 * K-dimensional tree, specialised for 2-dimensional space.
19 * This is not intended as a primary storage of data, but as an index into existing data.
20 * Usually the type stored by this tree should be an index into an existing array.
22 * This implementation assumes Manhattan distances are used.
24 * Be careful when using this in game code: depending on the usage pattern, the tree shape may
25 * end up different for different clients in multiplayer, causing iteration order to differ
26 * and elements possibly being returned in a different order. The using code should be designed
27 * to produce the same result regardless of iteration order.
29 * The element type T must be less-than comparable for FindNearest to work.
31 * @tparam T Type stored in the tree, should be cheap to copy.
32 * @tparam TxyFunc Functor type to extract coordinate from a T value and dimension index (0 or 1).
33 * @tparam CoordT Type of coordinate values extracted via TxyFunc.
34 * @tparam DistT Type to use for representing distance values.
36 template <typename T
, typename TxyFunc
, typename CoordT
, typename DistT
>
/* NOTE(review): the declaration line that opens this node type, and its closing brace, are not
 * visible in this extract; only the members and the constructor are shown. */
38 /** Type of a node in the tree */
/* Child links are indices into the node pool, not pointers. */
40 T element
; ///< Element stored at node
41 size_t left
; ///< Index of node to the left, INVALID_NODE if none
42 size_t right
; ///< Index of node to the right, INVALID_NODE if none
/* Construct a node holding a copy of the element, with both child links set to INVALID_NODE. */
44 node(T element
) : element(element
), left(INVALID_NODE
), right(INVALID_NODE
) { }
/* Sentinel used for "no child" links and for the root of an empty tree. */
47 static const size_t INVALID_NODE
= SIZE_MAX
; ///< Index value indicating no-such-node
/* Trees holding fewer elements than this are never reported as unbalanced (IsUnbalanced)
 * and are never rebuilt (Rebuild returns false). */
48 static const size_t MIN_REBALANCE_THRESHOLD
= 8; ///< Arbitrary value for "not worth rebalancing"
/* Nodes refer to each other by index into this vector. */
50 std::vector
<node
> nodes
; ///< Pool of all nodes in the tree
/* Slots listed here are reused by AddNode before the pool is grown; the element count
 * is computed as nodes.size() - free_list.size(). */
51 std::vector
<size_t> free_list
; ///< List of dead indices in the nodes vector
52 size_t root
; ///< Index of root node
/* Incremented through IncrementUnbalanced, reset to 0 by Build, and compared against a
 * quarter of the element count in IsUnbalanced. */
53 size_t unbalanced
; ///< Number approximating how unbalanced the tree might be
/*
 * Both paths leave the new node with INVALID_NODE child links (see the node constructor).
 * Growing the pool may reallocate it, so callers must not keep references into
 * this->nodes across a call; InsertRecursive re-fetches its node for this reason.
 */
55 /** Create one new node in the tree, return its index in the pool */
56 size_t AddNode(const T
&element
)
/* No recycled slot available: append to the pool, the new node is its last entry. */
58 if (this->free_list
.empty()) {
59 this->nodes
.emplace_back(element
);
60 return this->nodes
.size() - 1;
/* Otherwise take the most recently freed slot and overwrite it with a fresh node. */
62 size_t newidx
= this->free_list
.back();
63 this->free_list
.pop_back();
64 this->nodes
[newidx
] = node
{ element
};
/* NOTE(review): no return statement for the reuse path is visible in this extract;
 * presumably newidx is returned - confirm against the full file. */
69 /** Find a coordinate value to split a range of elements at */
70 template <typename It
>
71 CoordT
SelectSplitCoord(It begin
, It end
, int level
)
73 It mid
= begin
+ (end
- begin
) / 2;
74 std::nth_element(begin
, mid
, end
, [&](T a
, T b
) { return TxyFunc()(a
, level
% 2) < TxyFunc()(b
, level
% 2); });
75 return TxyFunc()(*mid
, level
% 2);
/*
 * The input range is reordered in place: SelectSplitCoord runs std::nth_element on it
 * and std::partition is applied afterwards. Child sub-trees are built one level deeper,
 * so the splitting dimension alternates between levels.
 */
78 /** Construct a subtree from elements between begin and end iterators, return index of root */
79 template <typename It
>
80 size_t BuildSubtree(It begin
, It end
, int level
)
82 ptrdiff_t count
= end
- begin
;
/* NOTE(review): the branch preceding this "else if" is not visible in this extract;
 * presumably it handles an empty range - confirm against the full file. */
86 } else if (count
== 1) {
/* A single element becomes a leaf. */
87 return this->AddNode(*begin
);
88 } else if (count
> 1) {
89 CoordT split_coord
= SelectSplitCoord(begin
, end
, level
);
/* Elements strictly below the split coordinate are moved in front of "split". */
90 It split
= std::partition(begin
, end
, [&](T v
) { return TxyFunc()(v
, level
% 2) < split_coord
; });
/* The element at the partition point is stored in this node, everything before it
 * goes to the left sub-tree and everything after it to the right sub-tree. */
91 size_t newidx
= this->AddNode(*split
);
/* The pool is indexed afresh for each assignment instead of through a held reference.
 * NOTE(review): if AddNode can reallocate the pool here, this relies on the recursive
 * call being evaluated before nodes[newidx] is indexed (guaranteed from C++17 on). */
92 this->nodes
[newidx
].left
= this->BuildSubtree(begin
, split
, level
+ 1);
93 this->nodes
[newidx
].right
= this->BuildSubtree(split
+ 1, end
, level
+ 1);
/* NOTE(review): the return of newidx and the remaining branches are not visible in this extract. */
/*
 * @param include_element Element to add during the rebuild, or nullptr for none.
 * @param exclude_element Element to drop during the rebuild, or nullptr for none.
 * @return false, leaving the tree untouched, when it holds fewer than
 *         MIN_REBALANCE_THRESHOLD elements.
 * NOTE(review): the return statement for the success path is not visible in this extract.
 */
100 /** Rebuild the tree with all existing elements, optionally adding or removing one more */
101 bool Rebuild(const T
*include_element
, const T
*exclude_element
)
103 size_t initial_count
= this->Count();
104 if (initial_count
< MIN_REBALANCE_THRESHOLD
) return false;
/* FreeSubtree only hands back the elements of the descendants, so the root's own
 * element is copied out first and appended by hand. */
106 T root_element
= this->nodes
[this->root
].element
;
107 std::vector
<T
> elements
= this->FreeSubtree(this->root
);
108 elements
.push_back(root_element
);
110 if (include_element
!= nullptr) {
111 elements
.push_back(*include_element
);
/* NOTE(review): lines appear to be missing from this extract after the push_back above and
 * after the erase below. As shown, initial_count is never adjusted for the added or removed
 * element before the assertion at the end - confirm against the full file. */
114 if (exclude_element
!= nullptr) {
/* Drop every entry comparing equal to the excluded element. */
115 typename
std::vector
<T
>::iterator removed
= std::remove(elements
.begin(), elements
.end(), *exclude_element
);
116 elements
.erase(removed
, elements
.end());
/* Build resets the free list and the imbalance counter before constructing the new tree. */
120 this->Build(elements
.begin(), elements
.end());
121 dbg_assert(initial_count
== this->Count());
/*
 * Descends one level per call: an element whose coordinate is strictly smaller than the
 * node's goes left, otherwise it goes right. The element is attached as a new leaf at
 * the first empty child link found on that path.
 * @param element Element to insert.
 * @param node_idx Node to start the descent from.
 * @param level Depth of node_idx in the tree.
 */
125 /** Insert one element in the tree somewhere below node_idx */
126 void InsertRecursive(const T
&element
, size_t node_idx
, int level
)
/* NOTE(review): the declaration of "dim" is not visible in this extract; presumably it is
 * derived from level the same way as "level % 2" elsewhere in this file - confirm. */
128 /* Dimension index of current level */
131 node
&n
= this->nodes
[node_idx
];
133 /* Coordinate of element splitting at this node */
134 CoordT nc
= TxyFunc()(n
.element
, dim
);
135 /* Coordinate of the new element */
136 CoordT ec
= TxyFunc()(element
, dim
);
137 /* Which side to insert on */
138 size_t &next
= (ec
< nc
) ? n
.left
: n
.right
;
140 if (next
== INVALID_NODE
) {
/* Empty link: create the leaf, then attach it through a freshly fetched node reference. */
142 size_t newidx
= this->AddNode(element
);
143 /* Vector may have been reallocated at this point, n and next are invalid */
144 node
&nn
= this->nodes
[node_idx
];
145 if (ec
< nc
) nn
.left
= newidx
; else nn
.right
= newidx
;
/* Occupied link: keep descending on that side. */
147 this->InsertRecursive(element
, next
, level
+ 1);
152 * Free all children of the given node
153 * @return Collection of elements that were removed from tree.
155 std::vector
<T
> FreeSubtree(size_t node_idx
)
157 std::vector
<T
> subtree_elements
;
158 node
&n
= this->nodes
[node_idx
];
160 /* We'll be appending items to the free_list, get index of our first item */
161 size_t first_free
= this->free_list
.size();
162 /* Prepare the descent with our children */
163 if (n
.left
!= INVALID_NODE
) this->free_list
.push_back(n
.left
);
164 if (n
.right
!= INVALID_NODE
) this->free_list
.push_back(n
.right
);
165 n
.left
= n
.right
= INVALID_NODE
;
167 /* Recursively free the nodes being collected */
168 for (size_t i
= first_free
; i
< this->free_list
.size(); i
++) {
169 node
&fn
= this->nodes
[this->free_list
[i
]];
170 subtree_elements
.push_back(fn
.element
);
171 if (fn
.left
!= INVALID_NODE
) this->free_list
.push_back(fn
.left
);
172 if (fn
.right
!= INVALID_NODE
) this->free_list
.push_back(fn
.right
);
173 fn
.left
= fn
.right
= INVALID_NODE
;
176 return subtree_elements
;
/*
 * Summary of the visible logic: when the node at node_idx holds the element, its slot is
 * put on the free list and, unless it is a leaf, the elements below it are collected with
 * FreeSubtree and rebuilt into a replacement sub-tree at the same level. Otherwise the
 * search continues on the side selected by comparing coordinates (strictly smaller goes
 * left), and the parent's link is updated when the child sub-tree got a new root.
 */
180 * Find and remove one element from the tree.
181 * @param element The element to search for
182 * @param node_idx Sub-tree to search in
183 * @param level Current depth in the tree
184 * @return New root node index of the sub-tree processed
186 size_t RemoveRecursive(const T
&element
, size_t node_idx
, int level
)
189 node
&n
= this->nodes
[node_idx
];
191 if (n
.element
== element
) {
192 /* Remove this one */
193 this->free_list
.push_back(node_idx
);
194 if (n
.left
== INVALID_NODE
&& n
.right
== INVALID_NODE
) {
/* NOTE(review): the statement for the leaf case is not visible in this extract. */
195 /* Simple case, leaf, new child node for parent is "none" */
198 /* Complex case, rebuild the sub-tree */
199 std::vector
<T
> subtree_elements
= this->FreeSubtree(node_idx
);
/* The doubled ';' below is an empty statement and has no effect. */
200 return this->BuildSubtree(subtree_elements
.begin(), subtree_elements
.end(), level
);;
203 /* Search in a sub-tree */
/* NOTE(review): the declaration of "dim" is not visible in this extract; presumably it is
 * derived from level the same way as "level % 2" elsewhere in this file - confirm. */
204 /* Dimension index of current level */
206 /* Coordinate of element splitting at this node */
207 CoordT nc
= TxyFunc()(n
.element
, dim
);
208 /* Coordinate of the element being removed */
209 CoordT ec
= TxyFunc()(element
, dim
);
210 /* Which side to remove from */
211 size_t next
= (ec
< nc
) ? n
.left
: n
.right
;
212 dbg_assert(next
!= INVALID_NODE
); // node must exist somewhere and must be found before a leaf is reached
214 size_t new_branch
= this->RemoveRecursive(element
, next
, level
+ 1);
/* Only rewrite the child link when the sub-tree below actually got a different root. */
215 if (new_branch
!= next
) {
216 /* Vector may have been reallocated at this point, n and next are invalid */
217 node
&nn
= this->nodes
[node_idx
];
218 if (ec
< nc
) nn
.left
= new_branch
; else nn
.right
= new_branch
;
/* NOTE(review): the return statement for the search path is not visible in this extract. */
225 DistT
ManhattanDistance(const T
&element
, CoordT x
, CoordT y
) const
227 return abs((DistT
)TxyFunc()(element
, 0) - (DistT
)x
) + abs((DistT
)TxyFunc()(element
, 1) - (DistT
)y
);
/* "first" is the element and "second" its distance, as built in FindNearestRecursive. */
230 /** A data element and its distance to a searched-for point */
231 using node_distance
= std::pair
<T
, DistT
>;
232 /** Ordering function for node_distance objects, elements with equal distance are ordered by less-than comparison */
233 static node_distance
SelectNearestNodeDistance(const node_distance
&a
, const node_distance
&b
)
235 if (a
.second
< b
.second
) return a
;
236 if (b
.second
< a
.second
) return b
;
237 if (a
.first
< b
.first
) return a
;
238 if (b
.first
< a
.first
) return b
;
239 NOT_REACHED(); // a.first == b.first: same element must not be inserted twice
/*
 * @param xy Target point, xy[0] and xy[1] being the two coordinates.
 * @param node_idx Root of the sub-tree to search.
 * @param level Depth of node_idx in the tree.
 * @param limit Distance bound used only to decide whether the far side of the split is
 *              searched; it is passed on to the far-side recursion but not to the
 *              near-side one, which uses the default.
 * NOTE(review): the final return statement is not visible in this extract; presumably
 * "best" is returned - confirm against the full file.
 */
241 /** Search a sub-tree for the element nearest to a given point */
242 node_distance
FindNearestRecursive(CoordT xy
[2], size_t node_idx
, int level
, DistT limit
= std::numeric_limits
<DistT
>::max()) const
/* NOTE(review): the declaration of "dim" is not visible in this extract; presumably it is
 * derived from level the same way as "level % 2" elsewhere in this file - confirm. */
244 /* Dimension index of current level */
247 const node
&n
= this->nodes
[node_idx
];
249 /* Coordinate of element splitting at this node */
250 CoordT c
= TxyFunc()(n
.element
, dim
);
251 /* This node's distance to target */
252 DistT thisdist
= ManhattanDistance(n
.element
, xy
[0], xy
[1]);
253 /* Assume this node is the best choice for now */
254 node_distance best
= std::make_pair(n
.element
, thisdist
);
/* The near side is the one the target itself would be stored on. */
256 /* Next node to visit */
257 size_t next
= (xy
[dim
] < c
) ? n
.left
: n
.right
;
258 if (next
!= INVALID_NODE
) {
259 /* Check if there is a better node down the tree */
260 best
= SelectNearestNodeDistance(best
, this->FindNearestRecursive(xy
, next
, level
+ 1));
/* Tighten the bound with the best distance found so far. */
263 limit
= std::min(best
.second
, limit
);
265 /* Check if the distance from current best is worse than distance from target to splitting line,
266 * if it is we also need to check the other side of the split. */
267 size_t opposite
= (xy
[dim
] >= c
) ? n
.left
: n
.right
; // reverse of above
/* The distance to the splitting line is computed after casting both coordinates to int. */
268 if (opposite
!= INVALID_NODE
&& limit
>= abs((int)xy
[dim
] - (int)c
)) {
269 node_distance other_candidate
= this->FindNearestRecursive(xy
, opposite
, level
+ 1, limit
);
270 best
= SelectNearestNodeDistance(best
, other_candidate
);
/*
 * Report every element of a sub-tree that lies inside a rectangle.
 * @param p1 Lower corner of the rectangle, inclusive in both dimensions.
 * @param p2 Upper corner of the rectangle, exclusive in both dimensions.
 * @param node_idx Root of the sub-tree to search.
 * @param level Depth of node_idx in the tree.
 * @param outputter Callable invoked with each element found.
 */
276 template <typename Outputter
>
277 void FindContainedRecursive(CoordT p1
[2], CoordT p2
[2], size_t node_idx
, int level
, const Outputter
&outputter
) const
/* NOTE(review): the declaration of "dim" is not visible in this extract; presumably it is
 * derived from level the same way as "level % 2" elsewhere in this file - confirm. */
279 /* Dimension index of current level */
282 const node
&n
= this->nodes
[node_idx
];
284 /* Coordinate of element splitting at this node */
285 CoordT ec
= TxyFunc()(n
.element
, dim
);
286 /* Opposite coordinate of element */
287 CoordT oc
= TxyFunc()(n
.element
, 1 - dim
);
289 /* Test if this element is within rectangle */
290 if (ec
>= p1
[dim
] && ec
< p2
[dim
] && oc
>= p1
[1 - dim
] && oc
< p2
[1 - dim
]) outputter(n
.element
);
/* Sub-trees that cannot overlap the rectangle in the splitting dimension are skipped. */
292 /* Recurse left if part of rectangle is left of split */
293 if (p1
[dim
] < ec
&& n
.left
!= INVALID_NODE
) this->FindContainedRecursive(p1
, p2
, n
.left
, level
+ 1, outputter
);
295 /* Recurse right if part of rectangle is right of split */
296 if (p2
[dim
] > ec
&& n
.right
!= INVALID_NODE
) this->FindContainedRecursive(p1
, p2
, n
.right
, level
+ 1, outputter
);
299 /** Debugging function, counts number of occurrences of an element regardless of its correct position in the tree */
300 size_t CountValue(const T
&element
, size_t node_idx
) const
302 if (node_idx
== INVALID_NODE
) return 0;
303 const node
&n
= this->nodes
[node_idx
];
304 return CountValue(element
, n
.left
) + CountValue(element
, n
.right
) + ((n
.element
== element
) ? 1 : 0);
307 void IncrementUnbalanced(size_t amount
= 1)
309 this->unbalanced
+= amount
;
312 /** Check if the entire tree is in need of rebuilding */
313 bool IsUnbalanced() const
315 size_t count
= this->Count();
316 if (count
< MIN_REBALANCE_THRESHOLD
) return false;
317 return this->unbalanced
> count
/ 4;
320 /** Verify that the invariant is true for a sub-tree, dbg_assert if not */
321 void CheckInvariant(size_t node_idx
, int level
, CoordT min_x
, CoordT max_x
, CoordT min_y
, CoordT max_y
) const
323 if (node_idx
== INVALID_NODE
) return;
325 const node
&n
= this->nodes
[node_idx
];
326 CoordT cx
= TxyFunc()(n
.element
, 0);
327 CoordT cy
= TxyFunc()(n
.element
, 1);
329 dbg_assert(cx
>= min_x
);
330 dbg_assert(cx
< max_x
);
331 dbg_assert(cy
>= min_y
);
332 dbg_assert(cy
< max_y
);
334 if (level
% 2 == 0) {
335 // split in dimension 0 = x
336 CheckInvariant(n
.left
, level
+ 1, min_x
, cx
, min_y
, max_y
);
337 CheckInvariant(n
.right
, level
+ 1, cx
, max_x
, min_y
, max_y
);
339 // split in dimension 1 = y
340 CheckInvariant(n
.left
, level
+ 1, min_x
, max_x
, min_y
, cy
);
341 CheckInvariant(n
.right
, level
+ 1, min_x
, max_x
, cy
, max_y
);
/*
 * Starts the recursive check at the root with the full range of CoordT as bounds.
 * NOTE(review): the KDTREE_DEBUG guard mentioned in the summary below is not visible in
 * this extract - confirm against the full file.
 * NOTE(review): the recursive check treats upper bounds as exclusive, so an element whose
 * coordinate equals std::numeric_limits<CoordT>::max() would fail the assertion.
 */
345 /** Verify the invariant for the entire tree, does nothing unless KDTREE_DEBUG is defined */
346 void CheckInvariant() const
349 CheckInvariant(this->root
, 0, std::numeric_limits
<CoordT
>::min(), std::numeric_limits
<CoordT
>::max(), std::numeric_limits
<CoordT
>::min(), std::numeric_limits
<CoordT
>::max());
354 /** Construct a new Kdtree with the given xyfunc */
355 Kdtree() : root(INVALID_NODE
), unbalanced(0) { }
/*
 * The element sequence is reordered in place by BuildSubtree.
 * NOTE(review): a line appears to be missing from this extract before the free list is
 * cleared; no reset of the node pool or of the root is visible - confirm against the full file.
 */
358 * Clear and rebuild the tree from a new sequence of elements,
359 * @tparam It Iterator type for element sequence.
360 * @param begin First element in sequence.
361 * @param end One past last element in sequence.
363 template <typename It
>
364 void Build(It begin
, It end
)
367 this->free_list
.clear();
368 this->unbalanced
= 0;
/* Nothing to build for an empty sequence. */
369 if (begin
== end
) return;
/* Reserve room for one node per element up front. */
370 this->nodes
.reserve(end
- begin
);
372 this->root
= this->BuildSubtree(begin
, end
, 0);
/* NOTE(review): the declaration these statements belong to is not visible in this extract;
 * presumably this is the body of a function that empties the tree - confirm against the
 * full file. The visible part empties the free list and resets the imbalance estimate. */
382 this->free_list
.clear();
383 this->unbalanced
= 0;
/* NOTE(review): the declaration this statement belongs to is not visible in this extract.
 * The call passes nullptr for both the element to include and the element to exclude, and
 * the boolean result of the private Rebuild is not used on this line. */
388 * Reconstruct the tree with the same elements, letting it be fully balanced.
392 this->Rebuild(nullptr, nullptr);
/*
 * An element inserted into an empty tree becomes the root. Otherwise a full rebuild that
 * includes the new element is attempted when the tree is considered unbalanced; if no
 * rebuild happens, the element is inserted as a leaf and the imbalance estimate is raised.
 * NOTE(review): a line appears to be missing from this extract after the inner branch -
 * confirm against the full file.
 */
396 * Insert a single element in the tree.
397 * Repeatedly inserting single elements may cause the tree to become unbalanced.
398 * Undefined behaviour if the element already exists in the tree.
400 void Insert(const T
&element
)
402 if (this->Count() == 0) {
403 this->root
= this->AddNode(element
);
/* Rebuild is only called when IsUnbalanced() is true; it returns false when it did nothing. */
405 if (!this->IsUnbalanced() || !this->Rebuild(&element
, nullptr)) {
406 this->InsertRecursive(element
, this->root
, 0);
407 this->IncrementUnbalanced();
/*
 * Nothing happens on an empty tree. Otherwise a full rebuild that leaves out the element
 * is attempted when the tree is considered unbalanced; if no rebuild happens, the element
 * is removed recursively and the imbalance estimate is raised.
 * NOTE(review): RemoveRecursive asserts that the element is found before an empty link is
 * reached, so on that path a missing element is not silently ignored in builds where
 * dbg_assert is active.
 * NOTE(review): a line appears to be missing from this extract at the end of the function -
 * confirm against the full file.
 */
414 * Remove a single element from the tree, if it exists.
415 * Since elements are stored in interior nodes as well as leaf nodes, removing one may
416 * require a larger sub-tree to be re-built. Because of this, worst case run time is
417 * as bad as a full tree rebuild.
419 void Remove(const T
&element
)
421 size_t count
= this->Count();
422 if (count
== 0) return;
/* Rebuild is only called when IsUnbalanced() is true; it returns false when it did nothing. */
423 if (!this->IsUnbalanced() || !this->Rebuild(nullptr, &element
)) {
424 /* If the removed element is the root node, this modifies this->root */
425 this->root
= this->RemoveRecursive(element
, this->root
, 0);
426 this->IncrementUnbalanced();
/* NOTE(review): the declaration line of this function is not visible in this extract;
 * presumably this is the Count() called throughout the class - confirm against the full file. */
431 /** Get number of elements stored in tree */
/* Every free-list entry marks one unused slot of the pool, so live elements are the difference. */
434 dbg_assert(this->free_list
.size() <= this->nodes
.size());
435 return this->nodes
.size() - this->free_list
.size();
439 * Find the element closest to given coordinate, in Manhattan distance.
440 * For multiple elements with the same distance, the one comparing smaller with
441 * a less-than comparison is chosen.
443 T
FindNearest(CoordT x
, CoordT y
) const
445 dbg_assert(this->Count() > 0);
447 CoordT xy
[2] = { x
, y
};
448 return this->FindNearestRecursive(xy
, this->root
, 0).first
;
/*
 * Packs the corner coordinates into two arrays and starts the recursive search at the root.
 * NOTE(review): lines appear to be missing from this extract at the start of the body; no
 * check of the documented x1<x2 && y1<y2 precondition is visible - confirm against the full file.
 */
452 * Find all items contained within the given rectangle.
453 * @note Start coordinates are inclusive, end coordinates are exclusive. x1<x2 && y1<y2 is a precondition.
454 * @param x1 Start first coordinate, points found are greater or equals to this.
455 * @param y1 Start second coordinate, points found are greater or equals to this.
456 * @param x2 End first coordinate, points found are less than this.
457 * @param y2 End second coordinate, points found are less than this.
458 * @param outputter Callback used to return values from the search.
460 template <typename Outputter
>
461 void FindContained(CoordT x1
, CoordT y1
, CoordT x2
, CoordT y2
, const Outputter
&outputter
) const
/* An empty tree has no valid root to start from. */
466 if (this->Count() == 0) return;
468 CoordT p1
[2] = { x1
, y1
};
469 CoordT p2
[2] = { x2
, y2
};
470 this->FindContainedRecursive(p1
, p2
, this->root
, 0, outputter
);
474 * Find all items contained within the given rectangle.
475 * @note End coordinates are exclusive, x1<x2 && y1<y2 is a precondition.
477 std::vector
<T
> FindContained(CoordT x1
, CoordT y1
, CoordT x2
, CoordT y2
) const
479 std::vector
<T
> result
;
480 this->FindContained(x1
, y1
, x2
, y2
, [&result
](T e
) {result
.push_back(e
); });