[JITLink][arm64] Support arm64e JIT'd code (initially enabled for MachO only).
[llvm-project.git] / offload / src / PluginManager.cpp
blob 315b953f9b31ac6787303ae962d6ea4d73f21f56
1 //===-- PluginManager.cpp - Plugin loading and communication API ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Functionality for handling plugins.
11 //===----------------------------------------------------------------------===//
13 #include "PluginManager.h"
14 #include "Shared/Debug.h"
15 #include "Shared/Profile.h"
16 #include "device.h"
18 #include "llvm/Support/Error.h"
19 #include "llvm/Support/ErrorHandling.h"
20 #include <memory>
22 using namespace llvm;
23 using namespace llvm::sys;
25 PluginManager *PM = nullptr;
27 // Every plugin exports this method to create an instance of the plugin type.
28 #define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name();
29 #include "Shared/Targets.def"
31 void PluginManager::init() {
32 TIMESCOPE();
33 DP("Loading RTLs...\n");
35 // Attempt to create an instance of each supported plugin.
36 #define PLUGIN_TARGET(Name) \
37 do { \
38 Plugins.emplace_back( \
39 std::unique_ptr<GenericPluginTy>(createPlugin_##Name())); \
40 } while (false);
41 #include "Shared/Targets.def"
43 DP("RTLs loaded!\n");
46 void PluginManager::deinit() {
47 TIMESCOPE();
48 DP("Unloading RTLs...\n");
50 for (auto &Plugin : Plugins) {
51 if (!Plugin->is_initialized())
52 continue;
54 if (auto Err = Plugin->deinit()) {
55 [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
56 DP("Failed to deinit plugin: %s\n", InfoMsg.c_str());
58 Plugin.release();
61 DP("RTLs unloaded!\n");
64 bool PluginManager::initializePlugin(GenericPluginTy &Plugin) {
65 if (Plugin.is_initialized())
66 return true;
68 if (auto Err = Plugin.init()) {
69 [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
70 DP("Failed to init plugin: %s\n", InfoMsg.c_str());
71 return false;
74 DP("Registered plugin %s with %d visible device(s)\n", Plugin.getName(),
75 Plugin.number_of_devices());
76 return true;
79 bool PluginManager::initializeDevice(GenericPluginTy &Plugin,
80 int32_t DeviceId) {
81 if (Plugin.is_device_initialized(DeviceId)) {
82 auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
83 (*ExclusiveDevicesAccessor)[PM->DeviceIds[std::make_pair(&Plugin,
84 DeviceId)]]
85 ->setHasPendingImages(true);
86 return true;
89 // Initialize the device information for the RTL we are about to use.
90 auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
92 int32_t UserId = ExclusiveDevicesAccessor->size();
94 // Set the device identifier offset in the plugin.
95 #ifdef OMPT_SUPPORT
96 Plugin.set_device_identifier(UserId, DeviceId);
97 #endif
99 auto Device = std::make_unique<DeviceTy>(&Plugin, UserId, DeviceId);
100 if (auto Err = Device->init()) {
101 [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
102 DP("Failed to init device %d: %s\n", DeviceId, InfoMsg.c_str());
103 return false;
106 ExclusiveDevicesAccessor->push_back(std::move(Device));
108 // We need to map between the plugin's device identifier and the one
109 // that OpenMP will use.
110 PM->DeviceIds[std::make_pair(&Plugin, DeviceId)] = UserId;
112 return true;
115 void PluginManager::initializeAllDevices() {
116 for (auto &Plugin : plugins()) {
117 if (!initializePlugin(Plugin))
118 continue;
120 for (int32_t DeviceId = 0; DeviceId < Plugin.number_of_devices();
121 ++DeviceId) {
122 initializeDevice(Plugin, DeviceId);
127 void PluginManager::registerLib(__tgt_bin_desc *Desc) {
128 PM->RTLsMtx.lock();
130 // Add in all the OpenMP requirements associated with this binary.
131 for (__tgt_offload_entry &Entry :
132 llvm::make_range(Desc->HostEntriesBegin, Desc->HostEntriesEnd))
133 if (Entry.flags == OMP_REGISTER_REQUIRES)
134 PM->addRequirements(Entry.data);
136 // Extract the exectuable image and extra information if availible.
137 for (int32_t i = 0; i < Desc->NumDeviceImages; ++i)
138 PM->addDeviceImage(*Desc, Desc->DeviceImages[i]);
140 // Register the images with the RTLs that understand them, if any.
141 for (DeviceImageTy &DI : PM->deviceImages()) {
142 // Obtain the image and information that was previously extracted.
143 __tgt_device_image *Img = &DI.getExecutableImage();
145 GenericPluginTy *FoundRTL = nullptr;
147 // Scan the RTLs that have associated images until we find one that supports
148 // the current image.
149 for (auto &R : plugins()) {
150 if (!R.is_plugin_compatible(Img))
151 continue;
153 if (!initializePlugin(R))
154 continue;
156 if (!R.number_of_devices()) {
157 DP("Skipping plugin %s with no visible devices\n", R.getName());
158 continue;
161 for (int32_t DeviceId = 0; DeviceId < R.number_of_devices(); ++DeviceId) {
162 if (!R.is_device_compatible(DeviceId, Img))
163 continue;
165 DP("Image " DPxMOD " is compatible with RTL %s device %d!\n",
166 DPxPTR(Img->ImageStart), R.getName(), DeviceId);
168 if (!initializeDevice(R, DeviceId))
169 continue;
171 // Initialize (if necessary) translation table for this library.
172 PM->TrlTblMtx.lock();
173 if (!PM->HostEntriesBeginToTransTable.count(Desc->HostEntriesBegin)) {
174 PM->HostEntriesBeginRegistrationOrder.push_back(
175 Desc->HostEntriesBegin);
176 TranslationTable &TT =
177 (PM->HostEntriesBeginToTransTable)[Desc->HostEntriesBegin];
178 TT.HostTable.EntriesBegin = Desc->HostEntriesBegin;
179 TT.HostTable.EntriesEnd = Desc->HostEntriesEnd;
182 // Retrieve translation table for this library.
183 TranslationTable &TT =
184 (PM->HostEntriesBeginToTransTable)[Desc->HostEntriesBegin];
186 DP("Registering image " DPxMOD " with RTL %s!\n",
187 DPxPTR(Img->ImageStart), R.getName());
189 auto UserId = PM->DeviceIds[std::make_pair(&R, DeviceId)];
190 if (TT.TargetsTable.size() < static_cast<size_t>(UserId + 1)) {
191 TT.DeviceTables.resize(UserId + 1, {});
192 TT.TargetsImages.resize(UserId + 1, nullptr);
193 TT.TargetsEntries.resize(UserId + 1, {});
194 TT.TargetsTable.resize(UserId + 1, nullptr);
197 // Register the image for this target type and invalidate the table.
198 TT.TargetsImages[UserId] = Img;
199 TT.TargetsTable[UserId] = nullptr;
201 PM->UsedImages.insert(Img);
202 FoundRTL = &R;
204 PM->TrlTblMtx.unlock();
207 if (!FoundRTL)
208 DP("No RTL found for image " DPxMOD "!\n", DPxPTR(Img->ImageStart));
210 PM->RTLsMtx.unlock();
212 bool UseAutoZeroCopy = Plugins.size() > 0;
214 auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
215 for (const auto &Device : *ExclusiveDevicesAccessor)
216 UseAutoZeroCopy &= Device->useAutoZeroCopy();
218 // Auto Zero-Copy can only be currently triggered when the system is an
219 // homogeneous APU architecture without attached discrete GPUs.
220 // If all devices suggest to use it, change requirment flags to trigger
221 // zero-copy behavior when mapping memory.
222 if (UseAutoZeroCopy)
223 addRequirements(OMPX_REQ_AUTO_ZERO_COPY);
225 DP("Done registering entries!\n");
228 // Temporary forward declaration, old style CTor/DTor handling is going away.
229 int target(ident_t *Loc, DeviceTy &Device, void *HostPtr,
230 KernelArgsTy &KernelArgs, AsyncInfoTy &AsyncInfo);
232 void PluginManager::unregisterLib(__tgt_bin_desc *Desc) {
233 DP("Unloading target library!\n");
235 PM->RTLsMtx.lock();
236 // Find which RTL understands each image, if any.
237 for (DeviceImageTy &DI : PM->deviceImages()) {
238 // Obtain the image and information that was previously extracted.
239 __tgt_device_image *Img = &DI.getExecutableImage();
241 GenericPluginTy *FoundRTL = NULL;
243 // Scan the RTLs that have associated images until we find one that supports
244 // the current image. We only need to scan RTLs that are already being used.
245 for (auto &R : plugins()) {
246 if (R.is_initialized())
247 continue;
249 // Ensure that we do not use any unused images associated with this RTL.
250 if (!UsedImages.contains(Img))
251 continue;
253 FoundRTL = &R;
255 DP("Unregistered image " DPxMOD " from RTL\n", DPxPTR(Img->ImageStart));
257 break;
260 // if no RTL was found proceed to unregister the next image
261 if (!FoundRTL) {
262 DP("No RTLs in use support the image " DPxMOD "!\n",
263 DPxPTR(Img->ImageStart));
266 PM->RTLsMtx.unlock();
267 DP("Done unregistering images!\n");
269 // Remove entries from PM->HostPtrToTableMap
270 PM->TblMapMtx.lock();
271 for (__tgt_offload_entry *Cur = Desc->HostEntriesBegin;
272 Cur < Desc->HostEntriesEnd; ++Cur) {
273 PM->HostPtrToTableMap.erase(Cur->addr);
276 // Remove translation table for this descriptor.
277 auto TransTable =
278 PM->HostEntriesBeginToTransTable.find(Desc->HostEntriesBegin);
279 if (TransTable != PM->HostEntriesBeginToTransTable.end()) {
280 DP("Removing translation table for descriptor " DPxMOD "\n",
281 DPxPTR(Desc->HostEntriesBegin));
282 PM->HostEntriesBeginToTransTable.erase(TransTable);
283 } else {
284 DP("Translation table for descriptor " DPxMOD " cannot be found, probably "
285 "it has been already removed.\n",
286 DPxPTR(Desc->HostEntriesBegin));
289 PM->TblMapMtx.unlock();
291 DP("Done unregistering library!\n");
294 /// Map global data and execute pending ctors
295 static int loadImagesOntoDevice(DeviceTy &Device) {
297 * Map global data
299 int32_t DeviceId = Device.DeviceID;
300 int Rc = OFFLOAD_SUCCESS;
302 std::lock_guard<decltype(PM->TrlTblMtx)> LG(PM->TrlTblMtx);
303 for (auto *HostEntriesBegin : PM->HostEntriesBeginRegistrationOrder) {
304 TranslationTable *TransTable =
305 &PM->HostEntriesBeginToTransTable[HostEntriesBegin];
306 DP("Trans table %p : %p\n", TransTable->HostTable.EntriesBegin,
307 TransTable->HostTable.EntriesEnd);
308 if (TransTable->HostTable.EntriesBegin ==
309 TransTable->HostTable.EntriesEnd) {
310 // No host entry so no need to proceed
311 continue;
314 if (TransTable->TargetsTable[DeviceId] != 0) {
315 // Library entries have already been processed
316 continue;
319 // 1) get image.
320 assert(TransTable->TargetsImages.size() > (size_t)DeviceId &&
321 "Not expecting a device ID outside the table's bounds!");
322 __tgt_device_image *Img = TransTable->TargetsImages[DeviceId];
323 if (!Img) {
324 REPORT("No image loaded for device id %d.\n", DeviceId);
325 Rc = OFFLOAD_FAIL;
326 break;
329 // 2) Load the image onto the given device.
330 auto BinaryOrErr = Device.loadBinary(Img);
331 if (llvm::Error Err = BinaryOrErr.takeError()) {
332 REPORT("Failed to load image %s\n",
333 llvm::toString(std::move(Err)).c_str());
334 Rc = OFFLOAD_FAIL;
335 break;
338 // 3) Create the translation table.
339 llvm::SmallVector<__tgt_offload_entry> &DeviceEntries =
340 TransTable->TargetsEntries[DeviceId];
341 for (__tgt_offload_entry &Entry :
342 llvm::make_range(Img->EntriesBegin, Img->EntriesEnd)) {
343 __tgt_device_binary &Binary = *BinaryOrErr;
345 __tgt_offload_entry DeviceEntry = Entry;
346 if (Entry.size) {
347 if (Device.RTL->get_global(Binary, Entry.size, Entry.name,
348 &DeviceEntry.addr) != OFFLOAD_SUCCESS)
349 REPORT("Failed to load symbol %s\n", Entry.name);
351 // If unified memory is active, the corresponding global is a device
352 // reference to the host global. We need to initialize the pointer on
353 // the device to point to the memory on the host.
354 if ((PM->getRequirements() & OMP_REQ_UNIFIED_SHARED_MEMORY) ||
355 (PM->getRequirements() & OMPX_REQ_AUTO_ZERO_COPY)) {
356 if (Device.RTL->data_submit(DeviceId, DeviceEntry.addr, Entry.addr,
357 Entry.size) != OFFLOAD_SUCCESS)
358 REPORT("Failed to write symbol for USM %s\n", Entry.name);
360 } else if (Entry.addr) {
361 if (Device.RTL->get_function(Binary, Entry.name, &DeviceEntry.addr) !=
362 OFFLOAD_SUCCESS)
363 REPORT("Failed to load kernel %s\n", Entry.name);
365 DP("Entry point " DPxMOD " maps to%s %s (" DPxMOD ")\n",
366 DPxPTR(Entry.addr), (Entry.size) ? " global" : "", Entry.name,
367 DPxPTR(DeviceEntry.addr));
369 DeviceEntries.emplace_back(DeviceEntry);
372 // Set the storage for the table and get a pointer to it.
373 __tgt_target_table DeviceTable{&DeviceEntries[0],
374 &DeviceEntries[0] + DeviceEntries.size()};
375 TransTable->DeviceTables[DeviceId] = DeviceTable;
376 __tgt_target_table *TargetTable = TransTable->TargetsTable[DeviceId] =
377 &TransTable->DeviceTables[DeviceId];
379 // 4) Verify whether the two table sizes match.
380 size_t Hsize =
381 TransTable->HostTable.EntriesEnd - TransTable->HostTable.EntriesBegin;
382 size_t Tsize = TargetTable->EntriesEnd - TargetTable->EntriesBegin;
384 // Invalid image for these host entries!
385 if (Hsize != Tsize) {
386 REPORT(
387 "Host and Target tables mismatch for device id %d [%zx != %zx].\n",
388 DeviceId, Hsize, Tsize);
389 TransTable->TargetsImages[DeviceId] = 0;
390 TransTable->TargetsTable[DeviceId] = 0;
391 Rc = OFFLOAD_FAIL;
392 break;
395 MappingInfoTy::HDTTMapAccessorTy HDTTMap =
396 Device.getMappingInfo().HostDataToTargetMap.getExclusiveAccessor();
398 __tgt_target_table *HostTable = &TransTable->HostTable;
399 for (__tgt_offload_entry *CurrDeviceEntry = TargetTable->EntriesBegin,
400 *CurrHostEntry = HostTable->EntriesBegin,
401 *EntryDeviceEnd = TargetTable->EntriesEnd;
402 CurrDeviceEntry != EntryDeviceEnd;
403 CurrDeviceEntry++, CurrHostEntry++) {
404 if (CurrDeviceEntry->size == 0)
405 continue;
407 assert(CurrDeviceEntry->size == CurrHostEntry->size &&
408 "data size mismatch");
410 // Fortran may use multiple weak declarations for the same symbol,
411 // therefore we must allow for multiple weak symbols to be loaded from
412 // the fat binary. Treat these mappings as any other "regular"
413 // mapping. Add entry to map.
414 if (Device.getMappingInfo().getTgtPtrBegin(HDTTMap, CurrHostEntry->addr,
415 CurrHostEntry->size))
416 continue;
418 void *CurrDeviceEntryAddr = CurrDeviceEntry->addr;
420 // For indirect mapping, follow the indirection and map the actual
421 // target.
422 if (CurrDeviceEntry->flags & OMP_DECLARE_TARGET_INDIRECT) {
423 AsyncInfoTy AsyncInfo(Device);
424 void *DevPtr;
425 Device.retrieveData(&DevPtr, CurrDeviceEntryAddr, sizeof(void *),
426 AsyncInfo, /*Entry=*/nullptr, &HDTTMap);
427 if (AsyncInfo.synchronize() != OFFLOAD_SUCCESS)
428 return OFFLOAD_FAIL;
429 CurrDeviceEntryAddr = DevPtr;
432 DP("Add mapping from host " DPxMOD " to device " DPxMOD " with size %zu"
433 ", name \"%s\"\n",
434 DPxPTR(CurrHostEntry->addr), DPxPTR(CurrDeviceEntry->addr),
435 CurrDeviceEntry->size, CurrDeviceEntry->name);
436 HDTTMap->emplace(new HostDataToTargetTy(
437 (uintptr_t)CurrHostEntry->addr /*HstPtrBase*/,
438 (uintptr_t)CurrHostEntry->addr /*HstPtrBegin*/,
439 (uintptr_t)CurrHostEntry->addr + CurrHostEntry->size /*HstPtrEnd*/,
440 (uintptr_t)CurrDeviceEntryAddr /*TgtAllocBegin*/,
441 (uintptr_t)CurrDeviceEntryAddr /*TgtPtrBegin*/,
442 false /*UseHoldRefCount*/, CurrHostEntry->name,
443 true /*IsRefCountINF*/));
445 // Notify about the new mapping.
446 if (Device.notifyDataMapped(CurrHostEntry->addr, CurrHostEntry->size))
447 return OFFLOAD_FAIL;
450 Device.setHasPendingImages(false);
453 if (Rc != OFFLOAD_SUCCESS)
454 return Rc;
456 static Int32Envar DumpOffloadEntries =
457 Int32Envar("OMPTARGET_DUMP_OFFLOAD_ENTRIES", -1);
458 if (DumpOffloadEntries.get() == DeviceId)
459 Device.dumpOffloadEntries();
461 return OFFLOAD_SUCCESS;
464 Expected<DeviceTy &> PluginManager::getDevice(uint32_t DeviceNo) {
465 DeviceTy *DevicePtr;
467 auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
468 if (DeviceNo >= ExclusiveDevicesAccessor->size())
469 return createStringError(
470 inconvertibleErrorCode(),
471 "Device number '%i' out of range, only %i devices available",
472 DeviceNo, ExclusiveDevicesAccessor->size());
474 DevicePtr = &*(*ExclusiveDevicesAccessor)[DeviceNo];
477 // Check whether global data has been mapped for this device
478 if (DevicePtr->hasPendingImages())
479 if (loadImagesOntoDevice(*DevicePtr) != OFFLOAD_SUCCESS)
480 return createStringError(inconvertibleErrorCode(),
481 "Failed to load images on device '%i'",
482 DeviceNo);
483 return *DevicePtr;