1 //===- TrieRawHashMap.cpp -------------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/ADT/TrieRawHashMap.h"
10 #include "llvm/ADT/LazyAtomicPointer.h"
11 #include "llvm/ADT/StringExtras.h"
12 #include "llvm/ADT/TrieHashIndexGenerator.h"
13 #include "llvm/Support/Allocator.h"
14 #include "llvm/Support/Casting.h"
15 #include "llvm/Support/Debug.h"
16 #include "llvm/Support/ThreadSafeAllocator.h"
17 #include "llvm/Support/TrailingObjects.h"
18 #include "llvm/Support/raw_ostream.h"
/// Base type for nodes stored in the trie. A node is either an interior
/// sub-trie (TrieSubtrie) or a leaf holding content (TrieContent); the two are
/// distinguished by the IsSubtrie flag, which classof() below inspects for
/// LLVM-style RTTI (isa/dyn_cast/cast).
class TrieNode {
public:
  // Immutable after construction; set once by the derived-class constructor.
  const bool IsSubtrie = false;

  TrieNode(bool IsSubtrie) : IsSubtrie(IsSubtrie) {}

  // Plain global operator new/delete; nodes are placement-constructed into
  // larger allocations elsewhere, these exist for symmetry and direct use.
  static void *operator new(size_t Size) { return ::operator new(Size); }
  void operator delete(void *Ptr) { ::operator delete(Ptr); }
};
/// Leaf node: a header describing where the user's content and the stored
/// hash live relative to `this`. The content and hash bytes are allocated in
/// the same memory block as this header (see insert()), so all accessors are
/// simple offset arithmetic from `this`.
struct TrieContent final : public TrieNode {
  // Byte offset from `this` to the user's value.
  const uint8_t ContentOffset;
  // Number of hash bytes stored, and their byte offset from `this`.
  const uint8_t HashSize;
  const uint8_t HashOffset;

  /// Pointer to the user's value inside this allocation. const_cast is safe
  /// here: the underlying allocation is mutable, only the header is const.
  void *getValuePointer() const {
    auto *Content = reinterpret_cast<const uint8_t *>(this) + ContentOffset;
    return const_cast<uint8_t *>(Content);
  }

  /// The full hash recorded for this entry.
  ArrayRef<uint8_t> getHash() const {
    auto *Begin = reinterpret_cast<const uint8_t *>(this) + HashOffset;
    return ArrayRef(Begin, Begin + HashSize);
  }

  TrieContent(size_t ContentOffset, size_t HashSize, size_t HashOffset)
      : TrieNode(/*IsSubtrie=*/false), ContentOffset(ContentOffset),
        HashSize(HashSize), HashOffset(HashOffset) {}

  static bool classof(const TrieNode *TN) { return !TN->IsSubtrie; }
};

// The public header hard-codes the header size for layout computations; keep
// them in sync.
static_assert(sizeof(TrieContent) ==
                  ThreadSafeTrieRawHashMapBase::TrieContentBaseSize,
              "Check header assumption!");
/// Interior node: a power-of-two array of atomic slots, each empty or
/// pointing at a TrieContent leaf or a deeper TrieSubtrie. The slot array is
/// allocated inline after the object via TrailingObjects.
class TrieSubtrie final
    : public TrieNode,
      private TrailingObjects<TrieSubtrie, LazyAtomicPointer<TrieNode>> {
public:
  using Slot = LazyAtomicPointer<TrieNode>;

  // Raw slot access; I must be < size().
  Slot &get(size_t I) { return getTrailingObjects<Slot>()[I]; }
  // Atomically load the node currently in slot I (may be null).
  TrieNode *load(size_t I) { return get(I).load(); }

  // Number of slots: 1 << NumBits.
  unsigned size() const { return Size; }

  TrieSubtrie *
  sink(size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI,
       function_ref<TrieSubtrie *(std::unique_ptr<TrieSubtrie>)> Saver);

  static std::unique_ptr<TrieSubtrie> create(size_t StartBit, size_t NumBits);

  explicit TrieSubtrie(size_t StartBit, size_t NumBits);

  static bool classof(const TrieNode *TN) { return TN->IsSubtrie; }

  /// Total bytes required for a TrieSubtrie with 2^NumBits trailing slots.
  static constexpr size_t sizeToAlloc(unsigned NumBits) {
    assert(NumBits < 20 && "Tries should have fewer than ~1M slots");
    unsigned Count = 1u << NumBits;
    return totalSizeToAlloc<LazyAtomicPointer<TrieNode>>(Count);
  }

private:
  // FIXME: Use a bitset to speed up access:
  //
  //     std::array<std::atomic<uint64_t>, NumSlots/64> IsSet;
  //
  // This will avoid needing to visit sparsely filled slots in
  // \a ThreadSafeTrieRawHashMapBase::destroyImpl() when there's a non-trivial
  // destructor.
  //
  // It would also greatly speed up iteration, if we add that some day, and
  // allow get() to return one level sooner.
  //
  // This would be the algorithm for updating IsSet (after updating Slots):
  //
  //     std::atomic<uint64_t> &Bits = IsSet[I.High];
  //     const uint64_t NewBit = 1ULL << I.Low;
  //     uint64_t Old = 0;
  //     while (!Bits.compare_exchange_weak(Old, Old | NewBit))
  //       ;

  // First hash bit indexed by this subtrie and how many bits it consumes.
  unsigned StartBit = 0;
  unsigned NumBits = 0;
  // Slot count, cached as 1 << NumBits by the constructor.
  unsigned Size = 0;
  friend class llvm::ThreadSafeTrieRawHashMapBase;
  friend class TrailingObjects;

public:
  /// Linked list for ownership of tries. The pointer is owned by TrieSubtrie.
  std::atomic<TrieSubtrie *> Next;
};
119 std::unique_ptr
<TrieSubtrie
> TrieSubtrie::create(size_t StartBit
,
121 void *Memory
= ::operator new(sizeToAlloc(NumBits
));
122 TrieSubtrie
*S
= ::new (Memory
) TrieSubtrie(StartBit
, NumBits
);
123 return std::unique_ptr
<TrieSubtrie
>(S
);
TrieSubtrie::TrieSubtrie(size_t StartBit, size_t NumBits)
    : TrieNode(true), StartBit(StartBit), NumBits(NumBits), Size(1u << NumBits),
      Next(nullptr) {
  // Placement-construct every trailing slot to empty; the raw trailing
  // storage is uninitialized memory until this runs.
  for (unsigned I = 0; I < Size; ++I)
    new (&get(I)) Slot(nullptr);
}

// ~TrieSubtrie never runs the slots' destructors, so they must be trivial.
static_assert(
    std::is_trivially_destructible<LazyAtomicPointer<TrieNode>>::value,
    "Expected no work in destructor for TrieNode");
// Sink the nodes down sub-trie when the object being inserted collides with
// the index of existing object in the trie. In this case, a new sub-trie needs
// to be allocated to hold existing object.
//
// Returns the subtrie now occupying slot I: the freshly created one if this
// thread won the race (ownership passed to Saver), or the one another thread
// installed concurrently (in which case "S" is destroyed on scope exit).
TrieSubtrie *TrieSubtrie::sink(
    size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI,
    function_ref<TrieSubtrie *(std::unique_ptr<TrieSubtrie>)> Saver) {
  // Create a new sub-trie that points to the existing object with the new
  // index for the next level.
  assert(NumSubtrieBits > 0);
  std::unique_ptr<TrieSubtrie> S = create(StartBit + NumBits, NumSubtrieBits);
  S->get(NewI).store(&Content);

  // Using compare_exchange to atomically add back the new sub-trie to the trie
  // in the place of the existing object.
  TrieNode *ExistingNode = &Content;
  if (get(I).compare_exchange_strong(ExistingNode, S.get()))
    return Saver(std::move(S));

  // Another thread created a subtrie already. Return it and let "S" be
  // destructed.
  return cast<TrieSubtrie>(ExistingNode);
}
/// Implementation object for the whole trie: owns the content allocator and
/// the root subtrie (stored inline as a trailing object), plus the ownership
/// list of all subtries hung off Root.Next.
class ThreadSafeTrieRawHashMapBase::ImplType final
    : private TrailingObjects<ThreadSafeTrieRawHashMapBase::ImplType,
                              TrieSubtrie> {
public:
  static std::unique_ptr<ImplType> create(size_t StartBit, size_t NumBits) {
    // Room for the ImplType header plus the root subtrie and its slots.
    size_t Size = sizeof(ImplType) + TrieSubtrie::sizeToAlloc(NumBits);
    void *Memory = ::operator new(Size);
    ImplType *Impl = ::new (Memory) ImplType(StartBit, NumBits);
    return std::unique_ptr<ImplType>(Impl);
  }

  // Save the Subtrie into the ownership list of the trie structure in a
  // thread-safe way. The ownership transfer is done by compare_exchange of the
  // pointer value inside the unique_ptr.
  TrieSubtrie *save(std::unique_ptr<TrieSubtrie> S) {
    assert(!S->Next && "Expected S to a freshly-constructed leaf");

    TrieSubtrie *CurrentHead = nullptr;
    // Add ownership of "S" to front of the list, so that Root -> S ->
    // Root.Next. This works by repeatedly setting S->Next to a candidate value
    // of Root.Next (initially nullptr), then setting Root.Next to S once the
    // candidate matches reality.
    while (!getRoot()->Next.compare_exchange_weak(CurrentHead, S.get()))
      S->Next.exchange(CurrentHead);

    // Ownership transferred to subtrie successfully. Release the unique_ptr.
    return S.release();
  }

  // Get the root which is the trailing object.
  TrieSubtrie *getRoot() { return getTrailingObjects<TrieSubtrie>(); }

  static void *operator new(size_t Size) { return ::operator new(Size); }
  void operator delete(void *Ptr) { ::operator delete(Ptr); }

  /// FIXME: This should take a function that allocates and constructs the
  /// content lazily (taking the hash as a separate parameter), in case of
  /// collision.
  ThreadSafeAllocator<BumpPtrAllocator> ContentAlloc;

private:
  friend class TrailingObjects;

  ImplType(size_t StartBit, size_t NumBits) {
    // Placement-construct the root subtrie into the trailing storage.
    ::new (getRoot()) TrieSubtrie(StartBit, NumBits);
  }
};
/// Lazily create the implementation object, racing against other threads;
/// exactly one thread's ImplType wins and all threads observe the same one.
ThreadSafeTrieRawHashMapBase::ImplType &
ThreadSafeTrieRawHashMapBase::getOrCreateImpl() {
  if (ImplType *Impl = ImplPtr.load())
    return *Impl;

  // Create a new ImplType and store it if another thread doesn't do so first.
  // If another thread wins this one is destroyed locally.
  std::unique_ptr<ImplType> Impl = ImplType::create(0, NumRootBits);
  ImplType *ExistingImpl = nullptr;

  // If the ownership transferred successfully, release unique_ptr and return
  // the pointer to the new ImplType.
  if (ImplPtr.compare_exchange_strong(ExistingImpl, Impl.get()))
    return *Impl.release();

  // Already created, return the existing ImplType.
  return *ExistingImpl;
}
/// Look up Hash. Returns a PointerBase pointing at the stored value on an
/// exact match; otherwise a "hint" PointerBase describing where an insert
/// would go (subtrie, slot index, start bit), or an empty PointerBase when
/// the trie has not been created yet.
ThreadSafeTrieRawHashMapBase::PointerBase
ThreadSafeTrieRawHashMapBase::find(ArrayRef<uint8_t> Hash) const {
  assert(!Hash.empty() && "Uninitialized hash");

  ImplType *Impl = ImplPtr.load();
  if (!Impl)
    return PointerBase();

  TrieSubtrie *S = Impl->getRoot();
  TrieHashIndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash};
  size_t Index = IndexGen.next();
  while (Index != IndexGen.end()) {
    // Empty slot: report where the entry would be inserted.
    TrieNode *Existing = S->get(Index);
    if (!Existing)
      return PointerBase(S, Index, *IndexGen.StartBit);

    // Check for an exact match.
    if (auto *ExistingContent = dyn_cast<TrieContent>(Existing))
      return ExistingContent->getHash() == Hash
                 ? PointerBase(ExistingContent->getValuePointer())
                 : PointerBase(S, Index, *IndexGen.StartBit);

    // Interior node: descend and consume the next chunk of hash bits.
    Index = IndexGen.next();
    S = cast<TrieSubtrie>(Existing);
  }
  llvm_unreachable("failed to locate the node after consuming all hash bytes");
}
/// Insert content for Hash (or return the existing entry on an exact match).
/// Hint may be a PointerBase previously returned by find() to skip already
/// traversed levels. Constructor is called at most once, with uninitialized
/// memory for the value, and must return the address where it stored the hash
/// bytes (within the same allocation).
ThreadSafeTrieRawHashMapBase::PointerBase ThreadSafeTrieRawHashMapBase::insert(
    PointerBase Hint, ArrayRef<uint8_t> Hash,
    function_ref<const uint8_t *(void *Mem, ArrayRef<uint8_t> Hash)>
        Constructor) {
  assert(!Hash.empty() && "Uninitialized hash");

  ImplType &Impl = getOrCreateImpl();
  TrieSubtrie *S = Impl.getRoot();
  TrieHashIndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash};
  size_t Index;
  if (Hint.isHint()) {
    // Resume from the subtrie/slot recorded in the hint.
    S = static_cast<TrieSubtrie *>(Hint.P);
    Index = IndexGen.hint(Hint.I, Hint.B);
  } else {
    Index = IndexGen.next();
  }

  while (Index != IndexGen.end()) {
    // Load the node from the slot, allocating and calling the constructor if
    // the slot is empty.
    bool Generated = false;
    TrieNode &Existing = S->get(Index).loadOrGenerate([&]() {
      Generated = true;

      // Construct the value itself at the tail.
      uint8_t *Memory = reinterpret_cast<uint8_t *>(
          Impl.ContentAlloc.Allocate(ContentAllocSize, ContentAllocAlign));
      const uint8_t *HashStorage = Constructor(Memory + ContentOffset, Hash);

      // Construct the TrieContent header, passing in the offset to the hash.
      TrieContent *Content = ::new (Memory)
          TrieContent(ContentOffset, Hash.size(), HashStorage - Memory);
      assert(Hash == Content->getHash() && "Hash not properly initialized");
      return Content;
    });
    // If we just generated it, return it!
    if (Generated)
      return PointerBase(cast<TrieContent>(Existing).getValuePointer());

    // Interior node: descend a level and continue.
    if (auto *ST = dyn_cast<TrieSubtrie>(&Existing)) {
      S = ST;
      Index = IndexGen.next();
      continue;
    }

    // Return the existing content if it's an exact match!
    auto &ExistingContent = cast<TrieContent>(Existing);
    if (ExistingContent.getHash() == Hash)
      return PointerBase(ExistingContent.getValuePointer());

    // Sink the existing content as long as the indexes match.
    size_t NextIndex = IndexGen.next();
    while (NextIndex != IndexGen.end()) {
      size_t NewIndexForExistingContent =
          IndexGen.getCollidingBits(ExistingContent.getHash());
      S = S->sink(Index, ExistingContent, IndexGen.getNumBits(),
                  NewIndexForExistingContent,
                  [&Impl](std::unique_ptr<TrieSubtrie> S) {
                    return Impl.save(std::move(S));
                  });
      Index = NextIndex;

      // Found the difference.
      if (NextIndex != NewIndexForExistingContent)
        break;

      NextIndex = IndexGen.next();
    }
  }
  llvm_unreachable("failed to insert the node after consuming all hash bytes");
}
ThreadSafeTrieRawHashMapBase::ThreadSafeTrieRawHashMapBase(
    size_t ContentAllocSize, size_t ContentAllocAlign, size_t ContentOffset,
    std::optional<size_t> NumRootBits, std::optional<size_t> NumSubtrieBits)
    : ContentAllocSize(ContentAllocSize), ContentAllocAlign(ContentAllocAlign),
      ContentOffset(ContentOffset),
      NumRootBits(NumRootBits ? *NumRootBits : DefaultNumRootBits),
      NumSubtrieBits(NumSubtrieBits ? *NumSubtrieBits : DefaultNumSubtrieBits),
      ImplPtr(nullptr) {
  // Assertion checks for reasonable configuration. The settings below are not
  // hard limits on most platforms, but a reasonable configuration should fall
  // within those limits.
  assert((!NumRootBits || *NumRootBits < 20) &&
         "Root should have fewer than ~1M slots");
  assert((!NumSubtrieBits || *NumSubtrieBits < 10) &&
         "Subtries should have fewer than ~1K slots");
}
// Move construction copies the configuration and steals the implementation
// pointer; RHS is left empty (ImplPtr == nullptr).
ThreadSafeTrieRawHashMapBase::ThreadSafeTrieRawHashMapBase(
    ThreadSafeTrieRawHashMapBase &&RHS)
    : ContentAllocSize(RHS.ContentAllocSize),
      ContentAllocAlign(RHS.ContentAllocAlign),
      ContentOffset(RHS.ContentOffset), NumRootBits(RHS.NumRootBits),
      NumSubtrieBits(RHS.NumSubtrieBits) {
  // Steal the root from RHS.
  ImplPtr = RHS.ImplPtr.exchange(nullptr);
}
// The base destructor cannot destroy content (it doesn't know the value
// type); subclasses must call destroyImpl() first.
ThreadSafeTrieRawHashMapBase::~ThreadSafeTrieRawHashMapBase() {
  assert(!ImplPtr.load() && "Expected subclass to call destroyImpl()");
}
/// Tear down the trie: run Destructor (if provided) over every stored value,
/// then free all subtries. Takes exclusive ownership of the implementation
/// via exchange, so it is safe against concurrent readers that already hold
/// no references.
void ThreadSafeTrieRawHashMapBase::destroyImpl(
    function_ref<void(void *)> Destructor) {
  std::unique_ptr<ImplType> Impl(ImplPtr.exchange(nullptr));
  if (!Impl)
    return;

  // Destroy content nodes throughout trie. Avoid destroying any subtries since
  // we need TrieNode::classof() to find the content nodes.
  //
  // FIXME: Once we have bitsets (see FIXME in TrieSubtrie class), use them to
  // facilitate sparse iteration here.
  if (Destructor)
    for (TrieSubtrie *Trie = Impl->getRoot(); Trie; Trie = Trie->Next.load())
      for (unsigned I = 0; I < Trie->size(); ++I)
        if (auto *Content = dyn_cast_or_null<TrieContent>(Trie->load(I)))
          Destructor(Content->getValuePointer());

  // Destroy the subtries. Incidentally, this destroys them in the reverse order
  // of saving.
  TrieSubtrie *Trie = Impl->getRoot()->Next;
  while (Trie) {
    TrieSubtrie *Next = Trie->Next.exchange(nullptr);
    delete Trie;
    Trie = Next;
  }
}
/// PointerBase for the root subtrie, or an empty PointerBase if the trie has
/// not been created yet.
ThreadSafeTrieRawHashMapBase::PointerBase
ThreadSafeTrieRawHashMapBase::getRoot() const {
  ImplType *Impl = ImplPtr.load();
  if (!Impl)
    return PointerBase();
  return PointerBase(Impl->getRoot());
}
397 unsigned ThreadSafeTrieRawHashMapBase::getStartBit(
398 ThreadSafeTrieRawHashMapBase::PointerBase P
) const {
399 assert(!P
.isHint() && "Not a valid trie");
402 if (auto *S
= dyn_cast
<TrieSubtrie
>((TrieNode
*)P
.P
))
407 unsigned ThreadSafeTrieRawHashMapBase::getNumBits(
408 ThreadSafeTrieRawHashMapBase::PointerBase P
) const {
409 assert(!P
.isHint() && "Not a valid trie");
412 if (auto *S
= dyn_cast
<TrieSubtrie
>((TrieNode
*)P
.P
))
/// Count the non-empty slots of the subtrie P points at; 0 when P is empty or
/// does not reference a subtrie. Snapshot only — concurrent inserts may
/// change the count immediately.
unsigned ThreadSafeTrieRawHashMapBase::getNumSlotUsed(
    ThreadSafeTrieRawHashMapBase::PointerBase P) const {
  assert(!P.isHint() && "Not a valid trie");
  if (!P.P)
    return 0;
  auto *S = dyn_cast<TrieSubtrie>((TrieNode *)P.P);
  if (!S)
    return 0;
  unsigned Num = 0;
  for (unsigned I = 0, E = S->size(); I < E; ++I)
    if (S->load(I))
      ++Num;
  return Num;
}
/// Render the hash-bit prefix covered by the subtrie P points at: full bytes
/// as lowercase hex, a trailing partial byte as "[bits]". Returns "" for an
/// empty/non-subtrie P. Used for debugging/dumping.
std::string ThreadSafeTrieRawHashMapBase::getTriePrefixAsString(
    ThreadSafeTrieRawHashMapBase::PointerBase P) const {
  assert(!P.isHint() && "Not a valid trie");
  if (!P.P)
    return "";

  auto *S = dyn_cast<TrieSubtrie>((TrieNode *)P.P);
  if (!S || !S->IsSubtrie)
    return "";

  // Find a TrieContent node which has hash stored. Depth search following the
  // first used slot until a TrieContent node is found.
  TrieSubtrie *Current = S;
  TrieContent *Node = nullptr;
  while (Current) {
    TrieSubtrie *Next = nullptr;
    // Find first used slot in the trie.
    for (unsigned I = 0, E = Current->size(); I < E; ++I) {
      auto *S = Current->load(I);
      if (!S)
        continue;

      if (auto *Content = dyn_cast<TrieContent>(S))
        Node = Content;
      else if (auto *Sub = dyn_cast<TrieSubtrie>(S))
        Next = Sub;
      break;
    }

    // Found the node.
    if (Node)
      break;

    // Continue to the next level if the node is not found.
    Current = Next;
  }

  assert(Node && "malformed trie, cannot find TrieContent on leaf node");
  // The prefix for the current trie is the first `StartBit` of the content
  // stored underneath this subtrie.
  std::string Str;
  raw_string_ostream SS(Str);

  // Whole bytes of prefix, printed as hex.
  unsigned StartFullBytes = (S->StartBit + 1) / 8 - 1;
  SS << toHex(toStringRef(Node->getHash()).take_front(StartFullBytes),
              /*LowerCase=*/true);

  // For the part of the prefix that doesn't fill a byte, print raw bit values.
  std::string Bits;
  for (unsigned I = StartFullBytes * 8, E = S->StartBit; I < E; ++I) {
    unsigned Index = I / 8;
    unsigned Offset = 7 - I % 8;
    Bits.push_back('0' + ((Node->getHash()[Index] >> Offset) & 1));
  }

  if (!Bits.empty())
    SS << "[" << Bits << "]";

  return SS.str();
}
493 unsigned ThreadSafeTrieRawHashMapBase::getNumTries() const {
494 ImplType
*Impl
= ImplPtr
.load();
498 for (TrieSubtrie
*Trie
= Impl
->getRoot(); Trie
; Trie
= Trie
->Next
.load())
503 ThreadSafeTrieRawHashMapBase::PointerBase
504 ThreadSafeTrieRawHashMapBase::getNextTrie(
505 ThreadSafeTrieRawHashMapBase::PointerBase P
) const {
506 assert(!P
.isHint() && "Not a valid trie");
508 return PointerBase();
509 auto *S
= dyn_cast
<TrieSubtrie
>((TrieNode
*)P
.P
);
511 return PointerBase();
512 if (auto *E
= S
->Next
.load())
513 return PointerBase(E
);
514 return PointerBase();