1 //===----------- MemoryManager.h - Target independent memory manager ------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Target independent memory manager.
11 //===----------------------------------------------------------------------===//
13 #ifndef LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_MEMORYMANAGER_H
14 #define LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_MEMORYMANAGER_H
21 #include <unordered_map>
25 #include "omptargetplugin.h"
27 /// Base class of per-device allocator.
28 class DeviceAllocatorTy
{
30 virtual ~DeviceAllocatorTy() = default;
32 /// Allocate a memory of size \p Size . \p HstPtr is used to assist the
34 virtual void *allocate(size_t Size
, void *HstPtr
, TargetAllocTy Kind
) = 0;
36 /// Delete the pointer \p TgtPtr on the device
37 virtual int free(void *TgtPtr
) = 0;
40 /// Class of memory manager. The memory manager is per-device by using
41 /// per-device allocator. Therefore, each plugin using memory manager should
42 /// have an allocator for each device.
43 class MemoryManagerTy
{
44 static constexpr const size_t BucketSize
[] = {
45 0, 1U << 2, 1U << 3, 1U << 4, 1U << 5, 1U << 6, 1U << 7,
46 1U << 8, 1U << 9, 1U << 10, 1U << 11, 1U << 12, 1U << 13};
48 static constexpr const int NumBuckets
=
49 sizeof(BucketSize
) / sizeof(BucketSize
[0]);
51 /// Find the previous number that is power of 2 given a number that is not
53 static size_t floorToPowerOfTwo(size_t Num
) {
59 #if INTPTR_MAX == INT64_MAX
61 #elif INTPTR_MAX == INT32_MAX
62 // Do nothing with 32-bit
64 #error Unsupported architecture
70 /// Find a suitable bucket
71 static int findBucket(size_t Size
) {
72 const size_t F
= floorToPowerOfTwo(Size
);
74 DP("findBucket: Size %zu is floored to %zu.\n", Size
, F
);
76 int L
= 0, H
= NumBuckets
- 1;
79 if (BucketSize
[M
] == F
)
81 if (BucketSize
[M
] > F
)
87 assert(L
>= 0 && L
< NumBuckets
&& "L is out of range");
89 DP("findBucket: Size %zu goes to bucket %d\n", Size
, L
);
94 /// A structure stores the meta data of a target pointer
102 NodeTy(size_t Size
, void *Ptr
) : Size(Size
), Ptr(Ptr
) {}
105 /// To make \p NodePtrTy ordered when they're put into \p std::multiset.
107 bool operator()(const NodeTy
&LHS
, const NodeTy
&RHS
) const {
108 return LHS
.Size
< RHS
.Size
;
112 /// A \p FreeList is a set of Nodes. We're using \p std::multiset here to make
113 /// the look up procedure more efficient.
114 using FreeListTy
= std::multiset
<std::reference_wrapper
<NodeTy
>, NodeCmpTy
>;
116 /// A list of \p FreeListTy entries, each of which is a \p std::multiset of
117 /// Nodes whose size is less or equal to a specific bucket size.
118 std::vector
<FreeListTy
> FreeLists
;
119 /// A list of mutex for each \p FreeListTy entry
120 std::vector
<std::mutex
> FreeListLocks
;
121 /// A table to map from a target pointer to its node
122 std::unordered_map
<void *, NodeTy
> PtrToNodeTable
;
123 /// The mutex for the table \p PtrToNodeTable
124 std::mutex MapTableLock
;
126 /// The reference to a device allocator
127 DeviceAllocatorTy
&DeviceAllocator
;
129 /// The threshold to manage memory using memory manager. If the request size
130 /// is larger than \p SizeThreshold, the allocation will not be managed by the
132 size_t SizeThreshold
= 1U << 13;
134 /// Request memory from target device
135 void *allocateOnDevice(size_t Size
, void *HstPtr
) const {
136 return DeviceAllocator
.allocate(Size
, HstPtr
, TARGET_ALLOC_DEVICE
);
139 /// Deallocate data on device
140 int deleteOnDevice(void *Ptr
) const { return DeviceAllocator
.free(Ptr
); }
142 /// This function is called when it tries to allocate memory on device but the
143 /// device returns out of memory. It will first free all memory in the
144 /// FreeList and try to allocate again.
145 void *freeAndAllocate(size_t Size
, void *HstPtr
) {
146 std::vector
<void *> RemoveList
;
148 // Deallocate all memory in FreeList
149 for (int I
= 0; I
< NumBuckets
; ++I
) {
150 FreeListTy
&List
= FreeLists
[I
];
151 std::lock_guard
<std::mutex
> Lock(FreeListLocks
[I
]);
154 for (const NodeTy
&N
: List
) {
155 deleteOnDevice(N
.Ptr
);
156 RemoveList
.push_back(N
.Ptr
);
158 FreeLists
[I
].clear();
161 // Remove all nodes in the map table which have been released
162 if (!RemoveList
.empty()) {
163 std::lock_guard
<std::mutex
> LG(MapTableLock
);
164 for (void *P
: RemoveList
)
165 PtrToNodeTable
.erase(P
);
168 // Try allocate memory again
169 return allocateOnDevice(Size
, HstPtr
);
172 /// The goal is to allocate memory on the device. It first tries to
173 /// allocate directly on the device. If a \p nullptr is returned, it might
174 /// be because the device is OOM. In that case, it will free all unused
175 /// memory and then try again.
176 void *allocateOrFreeAndAllocateOnDevice(size_t Size
, void *HstPtr
) {
177 void *TgtPtr
= allocateOnDevice(Size
, HstPtr
);
178 // We cannot get memory from the device. It might be due to OOM. Let's
179 // free all memory in FreeLists and try again.
180 if (TgtPtr
== nullptr) {
181 DP("Failed to get memory on device. Free all memory in FreeLists and "
183 TgtPtr
= freeAndAllocate(Size
, HstPtr
);
186 if (TgtPtr
== nullptr)
187 DP("Still cannot get memory on device probably because the device is "
194 /// Constructor. If \p Threshold is non-zero, then the default threshold will
195 /// be overwritten by \p Threshold.
196 MemoryManagerTy(DeviceAllocatorTy
&DeviceAllocator
, size_t Threshold
= 0)
197 : FreeLists(NumBuckets
), FreeListLocks(NumBuckets
),
198 DeviceAllocator(DeviceAllocator
) {
200 SizeThreshold
= Threshold
;
205 for (auto Itr
= PtrToNodeTable
.begin(); Itr
!= PtrToNodeTable
.end();
207 assert(Itr
->second
.Ptr
&& "nullptr in map table");
208 deleteOnDevice(Itr
->second
.Ptr
);
212 /// Allocate memory of size \p Size from target device. \p HstPtr is used to
213 /// assist the allocation.
214 void *allocate(size_t Size
, void *HstPtr
) {
215 // If the size is zero, we will not bother the target device. Just return
220 DP("MemoryManagerTy::allocate: size %zu with host pointer " DPxMOD
".\n",
221 Size
, DPxPTR(HstPtr
));
223 // If the size is greater than the threshold, allocate it directly from
225 if (Size
> SizeThreshold
) {
226 DP("%zu is greater than the threshold %zu. Allocate it directly from "
228 Size
, SizeThreshold
);
229 void *TgtPtr
= allocateOrFreeAndAllocateOnDevice(Size
, HstPtr
);
231 DP("Got target pointer " DPxMOD
". Return directly.\n", DPxPTR(TgtPtr
));
236 NodeTy
*NodePtr
= nullptr;
238 // Try to get a node from FreeList
240 const int B
= findBucket(Size
);
241 FreeListTy
&List
= FreeLists
[B
];
243 NodeTy
TempNode(Size
, nullptr);
244 std::lock_guard
<std::mutex
> LG(FreeListLocks
[B
]);
245 const auto Itr
= List
.find(TempNode
);
247 if (Itr
!= List
.end()) {
248 NodePtr
= &Itr
->get();
253 if (NodePtr
!= nullptr)
254 DP("Find one node " DPxMOD
" in the bucket.\n", DPxPTR(NodePtr
));
256 // We cannot find a valid node in FreeLists. Let's allocate on device and
257 // create a node for it.
258 if (NodePtr
== nullptr) {
259 DP("Cannot find a node in the FreeLists. Allocate on device.\n");
260 // Allocate one on device
261 void *TgtPtr
= allocateOrFreeAndAllocateOnDevice(Size
, HstPtr
);
263 if (TgtPtr
== nullptr)
266 // Create a new node and add it into the map table
268 std::lock_guard
<std::mutex
> Guard(MapTableLock
);
269 auto Itr
= PtrToNodeTable
.emplace(TgtPtr
, NodeTy(Size
, TgtPtr
));
270 NodePtr
= &Itr
.first
->second
;
273 DP("Node address " DPxMOD
", target pointer " DPxMOD
", size %zu\n",
274 DPxPTR(NodePtr
), DPxPTR(TgtPtr
), Size
);
277 assert(NodePtr
&& "NodePtr should not be nullptr at this point");
282 /// Deallocate memory pointed by \p TgtPtr
283 int free(void *TgtPtr
) {
284 DP("MemoryManagerTy::free: target memory " DPxMOD
".\n", DPxPTR(TgtPtr
));
288 // Look it up into the table
290 std::lock_guard
<std::mutex
> G(MapTableLock
);
291 auto Itr
= PtrToNodeTable
.find(TgtPtr
);
293 // We don't remove the node from the map table because the map does not
295 if (Itr
!= PtrToNodeTable
.end())
299 // The memory is not managed by the manager
301 DP("Cannot find its node. Delete it on device directly.\n");
302 return deleteOnDevice(TgtPtr
);
305 // Insert the node to the free list
306 const int B
= findBucket(P
->Size
);
308 DP("Found its node " DPxMOD
". Insert it to bucket %d.\n", DPxPTR(P
), B
);
311 std::lock_guard
<std::mutex
> G(FreeListLocks
[B
]);
312 FreeLists
[B
].insert(*P
);
315 return OFFLOAD_SUCCESS
;
318 /// Get the size threshold from the environment variable
319 /// \p LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD . Returns a <tt>
320 /// std::pair<size_t, bool> </tt> where the first element represents the
321 /// threshold and the second element represents whether user disables memory
322 /// manager explicitly by setting the var to 0. If user doesn't specify
323 /// anything, returns <0, true>.
324 static std::pair
<size_t, bool> getSizeThresholdFromEnv() {
325 size_t Threshold
= 0;
327 if (const char *Env
=
328 std::getenv("LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD")) {
329 Threshold
= std::stoul(Env
);
330 if (Threshold
== 0) {
331 DP("Disabled memory manager as user set "
332 "LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD=0.\n");
333 return std::make_pair(0, false);
337 return std::make_pair(Threshold
, true);
// Out-of-class definitions of the static constexpr data members. Before
// C++17 they are required whenever the members are odr-used (e.g. BucketSize
// indexed with a runtime subscript in findBucket). GCC rejects their absence
// in such cases even where Clang happens to accept it. Since C++17, constexpr
// static data members are implicitly inline, and these redeclarations are
// redundant and deprecated, so they are emitted only for pre-C++17 builds.
// NOTE(review): in pre-C++17 builds these are ordinary definitions living in
// a header. Including this header from more than one translation unit of the
// same binary would therefore define them more than once.
#if __cplusplus < 201703L
constexpr const size_t MemoryManagerTy::BucketSize[];
constexpr const int MemoryManagerTy::NumBuckets;
#endif
346 #endif // LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_MEMORYMANAGER_H