[SampleProfileLoader] Fix integer overflow in generateMDProfMetadata (#90217)
[llvm-project.git] / mlir / lib / Support / StorageUniquer.cpp
blob2e9b17e1e1c76cd9d990a1b2c75b4443c59ef8a3
1 //===- StorageUniquer.cpp - Common Storage Class Uniquer ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Support/StorageUniquer.h"
11 #include "mlir/Support/LLVM.h"
12 #include "mlir/Support/ThreadLocalCache.h"
13 #include "mlir/Support/TypeID.h"
14 #include "llvm/Support/RWMutex.h"
16 using namespace mlir;
17 using namespace mlir::detail;
19 namespace {
20 /// This class represents a uniquer for storage instances of a specific type
21 /// that has parametric storage. It contains all of the necessary data to unique
22 /// storage instances in a thread safe way. This allows for the main uniquer to
23 /// bucket each of the individual sub-types removing the need to lock the main
24 /// uniquer itself.
25 class ParametricStorageUniquer {
26 public:
27 using BaseStorage = StorageUniquer::BaseStorage;
28 using StorageAllocator = StorageUniquer::StorageAllocator;
30 /// A lookup key for derived instances of storage objects.
31 struct LookupKey {
32 /// The known hash value of the key.
33 unsigned hashValue;
35 /// An equality function for comparing with an existing storage instance.
36 function_ref<bool(const BaseStorage *)> isEqual;
39 private:
40 /// A utility wrapper object representing a hashed storage object. This class
41 /// contains a storage object and an existing computed hash value.
42 struct HashedStorage {
43 HashedStorage(unsigned hashValue = 0, BaseStorage *storage = nullptr)
44 : hashValue(hashValue), storage(storage) {}
45 unsigned hashValue;
46 BaseStorage *storage;
49 /// Storage info for derived TypeStorage objects.
50 struct StorageKeyInfo {
51 static inline HashedStorage getEmptyKey() {
52 return HashedStorage(0, DenseMapInfo<BaseStorage *>::getEmptyKey());
54 static inline HashedStorage getTombstoneKey() {
55 return HashedStorage(0, DenseMapInfo<BaseStorage *>::getTombstoneKey());
58 static inline unsigned getHashValue(const HashedStorage &key) {
59 return key.hashValue;
61 static inline unsigned getHashValue(const LookupKey &key) {
62 return key.hashValue;
65 static inline bool isEqual(const HashedStorage &lhs,
66 const HashedStorage &rhs) {
67 return lhs.storage == rhs.storage;
69 static inline bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) {
70 if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
71 return false;
72 // Invoke the equality function on the lookup key.
73 return lhs.isEqual(rhs.storage);
76 using StorageTypeSet = DenseSet<HashedStorage, StorageKeyInfo>;
78 /// This class represents a single shard of the uniquer. The uniquer uses a
79 /// set of shards to allow for multiple threads to create instances with less
80 /// lock contention.
81 struct Shard {
82 /// The set containing the allocated storage instances.
83 StorageTypeSet instances;
85 #if LLVM_ENABLE_THREADS != 0
86 /// A mutex to keep uniquing thread-safe.
87 llvm::sys::SmartRWMutex<true> mutex;
88 #endif
91 /// Get or create an instance of a param derived type in an thread-unsafe
92 /// fashion.
93 BaseStorage *getOrCreateUnsafe(Shard &shard, LookupKey &key,
94 function_ref<BaseStorage *()> ctorFn) {
95 auto existing = shard.instances.insert_as({key.hashValue}, key);
96 BaseStorage *&storage = existing.first->storage;
97 if (existing.second)
98 storage = ctorFn();
99 return storage;
102 /// Destroy all of the storage instances within the given shard.
103 void destroyShardInstances(Shard &shard) {
104 if (!destructorFn)
105 return;
106 for (HashedStorage &instance : shard.instances)
107 destructorFn(instance.storage);
110 public:
111 #if LLVM_ENABLE_THREADS != 0
112 /// Initialize the storage uniquer with a given number of storage shards to
113 /// use. The provided shard number is required to be a valid power of 2. The
114 /// destructor function is used to destroy any allocated storage instances.
115 ParametricStorageUniquer(function_ref<void(BaseStorage *)> destructorFn,
116 size_t numShards = 8)
117 : shards(new std::atomic<Shard *>[numShards]), numShards(numShards),
118 destructorFn(destructorFn) {
119 assert(llvm::isPowerOf2_64(numShards) &&
120 "the number of shards is required to be a power of 2");
121 for (size_t i = 0; i < numShards; i++)
122 shards[i].store(nullptr, std::memory_order_relaxed);
124 ~ParametricStorageUniquer() {
125 // Free all of the allocated shards.
126 for (size_t i = 0; i != numShards; ++i) {
127 if (Shard *shard = shards[i].load()) {
128 destroyShardInstances(*shard);
129 delete shard;
133 /// Get or create an instance of a parametric type.
134 BaseStorage *getOrCreate(bool threadingIsEnabled, unsigned hashValue,
135 function_ref<bool(const BaseStorage *)> isEqual,
136 function_ref<BaseStorage *()> ctorFn) {
137 Shard &shard = getShard(hashValue);
138 ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual};
139 if (!threadingIsEnabled)
140 return getOrCreateUnsafe(shard, lookupKey, ctorFn);
142 // Check for a instance of this object in the local cache.
143 auto localIt = localCache->insert_as({hashValue}, lookupKey);
144 BaseStorage *&localInst = localIt.first->storage;
145 if (localInst)
146 return localInst;
148 // Check for an existing instance in read-only mode.
150 llvm::sys::SmartScopedReader<true> typeLock(shard.mutex);
151 auto it = shard.instances.find_as(lookupKey);
152 if (it != shard.instances.end())
153 return localInst = it->storage;
156 // Acquire a writer-lock so that we can safely create the new storage
157 // instance.
158 llvm::sys::SmartScopedWriter<true> typeLock(shard.mutex);
159 return localInst = getOrCreateUnsafe(shard, lookupKey, ctorFn);
162 /// Run a mutation function on the provided storage object in a thread-safe
163 /// way.
164 LogicalResult mutate(bool threadingIsEnabled, BaseStorage *storage,
165 function_ref<LogicalResult()> mutationFn) {
166 if (!threadingIsEnabled)
167 return mutationFn();
169 // Get a shard to use for mutating this storage instance. It doesn't need to
170 // be the same shard as the original allocation, but does need to be
171 // deterministic.
172 Shard &shard = getShard(llvm::hash_value(storage));
173 llvm::sys::SmartScopedWriter<true> lock(shard.mutex);
174 return mutationFn();
177 private:
178 /// Return the shard used for the given hash value.
179 Shard &getShard(unsigned hashValue) {
180 // Get a shard number from the provided hashvalue.
181 unsigned shardNum = hashValue & (numShards - 1);
183 // Try to acquire an already initialized shard.
184 Shard *shard = shards[shardNum].load(std::memory_order_acquire);
185 if (shard)
186 return *shard;
188 // Otherwise, try to allocate a new shard.
189 Shard *newShard = new Shard();
190 if (shards[shardNum].compare_exchange_strong(shard, newShard))
191 return *newShard;
193 // If one was allocated before we can initialize ours, delete ours.
194 delete newShard;
195 return *shard;
198 /// A thread local cache for storage objects. This helps to reduce the lock
199 /// contention when an object already existing in the cache.
200 ThreadLocalCache<StorageTypeSet> localCache;
202 /// A set of uniquer shards to allow for further bucketing accesses for
203 /// instances of this storage type. Each shard is lazily initialized to reduce
204 /// the overhead when only a small amount of shards are in use.
205 std::unique_ptr<std::atomic<Shard *>[]> shards;
207 /// The number of available shards.
208 size_t numShards;
210 /// Function to used to destruct any allocated storage instances.
211 function_ref<void(BaseStorage *)> destructorFn;
213 #else
214 /// If multi-threading is disabled, ignore the shard parameter as we will
215 /// always use one shard. The destructor function is used to destroy any
216 /// allocated storage instances.
217 ParametricStorageUniquer(function_ref<void(BaseStorage *)> destructorFn,
218 size_t numShards = 0)
219 : destructorFn(destructorFn) {}
220 ~ParametricStorageUniquer() { destroyShardInstances(shard); }
222 /// Get or create an instance of a parametric type.
223 BaseStorage *
224 getOrCreate(bool threadingIsEnabled, unsigned hashValue,
225 function_ref<bool(const BaseStorage *)> isEqual,
226 function_ref<BaseStorage *()> ctorFn) {
227 ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual};
228 return getOrCreateUnsafe(shard, lookupKey, ctorFn);
230 /// Run a mutation function on the provided storage object in a thread-safe
231 /// way.
232 LogicalResult
233 mutate(bool threadingIsEnabled, BaseStorage *storage,
234 function_ref<LogicalResult()> mutationFn) {
235 return mutationFn();
238 private:
239 /// The main uniquer shard that is used for allocating storage instances.
240 Shard shard;
242 /// Function to used to destruct any allocated storage instances.
243 function_ref<void(BaseStorage *)> destructorFn;
244 #endif
246 } // namespace
248 namespace mlir {
249 namespace detail {
250 /// This is the implementation of the StorageUniquer class.
251 struct StorageUniquerImpl {
252 using BaseStorage = StorageUniquer::BaseStorage;
253 using StorageAllocator = StorageUniquer::StorageAllocator;
255 //===--------------------------------------------------------------------===//
256 // Parametric Storage
257 //===--------------------------------------------------------------------===//
259 /// Check if an instance of a parametric storage class exists.
260 bool hasParametricStorage(TypeID id) { return parametricUniquers.count(id); }
262 /// Get or create an instance of a parametric type.
263 BaseStorage *
264 getOrCreate(TypeID id, unsigned hashValue,
265 function_ref<bool(const BaseStorage *)> isEqual,
266 function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
267 assert(parametricUniquers.count(id) &&
268 "creating unregistered storage instance");
269 ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
270 return storageUniquer.getOrCreate(
271 threadingIsEnabled, hashValue, isEqual,
272 [&] { return ctorFn(getThreadSafeAllocator()); });
275 /// Run a mutation function on the provided storage object in a thread-safe
276 /// way.
277 LogicalResult
278 mutate(TypeID id, BaseStorage *storage,
279 function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
280 assert(parametricUniquers.count(id) &&
281 "mutating unregistered storage instance");
282 ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
283 return storageUniquer.mutate(threadingIsEnabled, storage, [&] {
284 return mutationFn(getThreadSafeAllocator());
288 /// Return an allocator that can be used to safely allocate instances on the
289 /// current thread.
290 StorageAllocator &getThreadSafeAllocator() {
291 #if LLVM_ENABLE_THREADS != 0
292 if (!threadingIsEnabled)
293 return allocator;
295 // If the allocator has not been initialized, create a new one.
296 StorageAllocator *&threadAllocator = threadSafeAllocator.get();
297 if (!threadAllocator) {
298 threadAllocator = new StorageAllocator();
300 // Record this allocator, given that we don't want it to be destroyed when
301 // the thread dies.
302 llvm::sys::SmartScopedLock<true> lock(threadAllocatorMutex);
303 threadAllocators.push_back(
304 std::unique_ptr<StorageAllocator>(threadAllocator));
307 return *threadAllocator;
308 #else
309 return allocator;
310 #endif
313 //===--------------------------------------------------------------------===//
314 // Singleton Storage
315 //===--------------------------------------------------------------------===//
317 /// Get or create an instance of a singleton storage class.
318 BaseStorage *getSingleton(TypeID id) {
319 BaseStorage *singletonInstance = singletonInstances[id];
320 assert(singletonInstance && "expected singleton instance to exist");
321 return singletonInstance;
324 /// Check if an instance of a singleton storage class exists.
325 bool hasSingleton(TypeID id) const { return singletonInstances.count(id); }
327 //===--------------------------------------------------------------------===//
328 // Instance Storage
329 //===--------------------------------------------------------------------===//
331 #if LLVM_ENABLE_THREADS != 0
332 /// A thread local set of allocators used for uniquing parametric instances,
333 /// or other data allocated in thread volatile situations.
334 ThreadLocalCache<StorageAllocator *> threadSafeAllocator;
336 /// All of the allocators that have been created for thread based allocation.
337 std::vector<std::unique_ptr<StorageAllocator>> threadAllocators;
339 /// A mutex used for safely adding a new thread allocator.
340 llvm::sys::SmartMutex<true> threadAllocatorMutex;
341 #endif
343 /// Main allocator used for uniquing singleton instances, and other state when
344 /// thread safety is guaranteed.
345 StorageAllocator allocator;
347 /// Map of type ids to the storage uniquer to use for registered objects.
348 DenseMap<TypeID, std::unique_ptr<ParametricStorageUniquer>>
349 parametricUniquers;
351 /// Map of type ids to a singleton instance when the storage class is a
352 /// singleton.
353 DenseMap<TypeID, BaseStorage *> singletonInstances;
355 /// Flag specifying if multi-threading is enabled within the uniquer.
356 bool threadingIsEnabled = true;
358 } // namespace detail
359 } // namespace mlir
361 StorageUniquer::StorageUniquer() : impl(new StorageUniquerImpl()) {}
362 StorageUniquer::~StorageUniquer() = default;
364 /// Set the flag specifying if multi-threading is disabled within the uniquer.
365 void StorageUniquer::disableMultithreading(bool disable) {
366 impl->threadingIsEnabled = !disable;
369 /// Implementation for getting/creating an instance of a derived type with
370 /// parametric storage.
371 auto StorageUniquer::getParametricStorageTypeImpl(
372 TypeID id, unsigned hashValue,
373 function_ref<bool(const BaseStorage *)> isEqual,
374 function_ref<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
375 return impl->getOrCreate(id, hashValue, isEqual, ctorFn);
378 /// Implementation for registering an instance of a derived type with
379 /// parametric storage.
380 void StorageUniquer::registerParametricStorageTypeImpl(
381 TypeID id, function_ref<void(BaseStorage *)> destructorFn) {
382 impl->parametricUniquers.try_emplace(
383 id, std::make_unique<ParametricStorageUniquer>(destructorFn));
386 /// Implementation for getting an instance of a derived type with default
387 /// storage.
388 auto StorageUniquer::getSingletonImpl(TypeID id) -> BaseStorage * {
389 return impl->getSingleton(id);
392 /// Test is the storage singleton is initialized.
393 bool StorageUniquer::isSingletonStorageInitialized(TypeID id) {
394 return impl->hasSingleton(id);
397 /// Test is the parametric storage is initialized.
398 bool StorageUniquer::isParametricStorageInitialized(TypeID id) {
399 return impl->hasParametricStorage(id);
402 /// Implementation for registering an instance of a derived type with default
403 /// storage.
404 void StorageUniquer::registerSingletonImpl(
405 TypeID id, function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
406 assert(!impl->singletonInstances.count(id) &&
407 "storage class already registered");
408 impl->singletonInstances.try_emplace(id, ctorFn(impl->allocator));
411 /// Implementation for mutating an instance of a derived storage.
412 LogicalResult StorageUniquer::mutateImpl(
413 TypeID id, BaseStorage *storage,
414 function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
415 return impl->mutate(id, storage, mutationFn);