1 //===-- runtime/CUDA/descriptor.cpp ---------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "flang/Runtime/CUDA/descriptor.h"
10 #include "../terminator.h"
11 #include "flang/Runtime/CUDA/allocator.h"
12 #include "flang/Runtime/CUDA/common.h"
14 #include "cuda_runtime.h"
16 namespace Fortran::runtime::cuda
{
18 RT_EXT_API_GROUP_BEGIN
20 Descriptor
*RTDEF(CUFAllocDesciptor
)(
21 std::size_t sizeInBytes
, const char *sourceFile
, int sourceLine
) {
22 return reinterpret_cast<Descriptor
*>(CUFAllocManaged(sizeInBytes
));
25 void RTDEF(CUFFreeDesciptor
)(
26 Descriptor
*desc
, const char *sourceFile
, int sourceLine
) {
27 CUFFreeManaged(reinterpret_cast<void *>(desc
));
30 void *RTDEF(CUFGetDeviceAddress
)(
31 void *hostPtr
, const char *sourceFile
, int sourceLine
) {
32 Terminator terminator
{sourceFile
, sourceLine
};
34 CUDA_REPORT_IF_ERROR(cudaGetSymbolAddress((void **)&p
, hostPtr
));
36 terminator
.Crash("Could not retrieve symbol's address");
41 void RTDEF(CUFDescriptorSync
)(Descriptor
*dst
, const Descriptor
*src
,
42 const char *sourceFile
, int sourceLine
) {
43 std::size_t count
{src
->SizeInBytes()};
44 CUDA_REPORT_IF_ERROR(cudaMemcpy(
45 (void *)dst
, (const void *)src
, count
, cudaMemcpyHostToDevice
));
50 } // namespace Fortran::runtime::cuda