1 //===-------- OmptInterface.h - Target independent OpenMP target RTL ------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Declarations for OpenMP Tool callback dispatchers
11 //===----------------------------------------------------------------------===//
13 #ifndef _OMPTARGET_OMPTINTERFACE_H
14 #define _OMPTARGET_OMPTINTERFACE_H
16 // Only provide functionality if target OMPT support is enabled
21 #include "OmptCallback.h"
22 #include "omp-tools.h"
24 #include "llvm/Support/ErrorHandling.h"
// Forwards its argument unchanged, so the wrapped statement is compiled in.
#define OMPT_IF_BUILT(stmt) stmt
// Expands to the compiler builtin __builtin_return_address for the requested
// frame level.
#define OMPT_GET_RETURN_ADDRESS(level) __builtin_return_address(level)
29 /// Callbacks for target regions require task_data representing the
30 /// encountering task.
31 /// Callbacks for target regions and target data ops require
32 /// target_task_data representing the target task region.
33 typedef ompt_data_t
*(*ompt_get_task_data_t
)();
34 typedef ompt_data_t
*(*ompt_get_target_task_data_t
)();
41 /// Function pointers that will be used to track task_data and
43 static ompt_get_task_data_t ompt_get_task_data_fn
;
44 static ompt_get_target_task_data_t ompt_get_target_task_data_fn
;
46 /// Used to maintain execution state for this thread
/// Top-level entry for invoking the tool callback ahead of a device data
/// allocation. \p TgtPtrBegin is taken as `void **` here, unlike the submit,
/// retrieve and delete entries which take a plain `void *`.
/// NOTE(review): \p Code is presumably a call-site code address (compare
/// OMPT_GET_RETURN_ADDRESS) -- confirm against the callers.
void beginTargetDataAlloc(int64_t DeviceId, void *HstPtrBegin,
                          void **TgtPtrBegin, size_t Size, void *Code);

/// Top-level entry for invoking the tool callback once a device data
/// allocation has finished; same parameter list as beginTargetDataAlloc.
void endTargetDataAlloc(int64_t DeviceId, void *HstPtrBegin,
                        void **TgtPtrBegin, size_t Size, void *Code);
/// Top-level entry for invoking the tool callback ahead of a data submit.
/// getCallbacks pairs it with endTargetDataSubmit for the
/// ompt_target_data_transfer_to_device[_async] operations.
void beginTargetDataSubmit(int64_t DeviceId, void *HstPtrBegin,
                           void *TgtPtrBegin, size_t Size, void *Code);

/// Top-level entry for invoking the tool callback once a data submit has
/// finished; same parameter list as beginTargetDataSubmit.
void endTargetDataSubmit(int64_t DeviceId, void *HstPtrBegin,
                         void *TgtPtrBegin, size_t Size, void *Code);
/// Top-level entry for invoking the tool callback ahead of a device data
/// deallocation. Only the target pointer is passed; there is no host pointer
/// or size parameter.
void beginTargetDataDelete(int64_t DeviceId, void *TgtPtrBegin, void *Code);

/// Top-level entry for invoking the tool callback once a device data
/// deallocation has finished; same parameter list as beginTargetDataDelete.
void endTargetDataDelete(int64_t DeviceId, void *TgtPtrBegin, void *Code);
/// Top-level entry for invoking the tool callback ahead of a data retrieve.
/// getCallbacks pairs it with endTargetDataRetrieve for the
/// ompt_target_data_transfer_from_device[_async] operations.
void beginTargetDataRetrieve(int64_t DeviceId, void *HstPtrBegin,
                             void *TgtPtrBegin, size_t Size, void *Code);

/// Top-level entry for invoking the tool callback once a data retrieve has
/// finished; same parameter list as beginTargetDataRetrieve.
void endTargetDataRetrieve(int64_t DeviceId, void *HstPtrBegin,
                           void *TgtPtrBegin, size_t Size, void *Code);
/// Top-level entry for invoking the tool callback ahead of a kernel dispatch.
/// \p NumTeams defaults to 1 when omitted.
void beginTargetSubmit(unsigned int NumTeams = 1);

/// Top-level entry for invoking the tool callback after a kernel dispatch.
/// \p NumTeams defaults to 1 when omitted.
void endTargetSubmit(unsigned int NumTeams = 1);
85 // Target region callbacks
/// Top-level entry for invoking the tool callback ahead of a target enter
/// data region.
void beginTargetDataEnter(int64_t DeviceId, void *Code);

/// Top-level entry for invoking the tool callback after a target enter data
/// region.
void endTargetDataEnter(int64_t DeviceId, void *Code);
/// Top-level entry for invoking the tool callback ahead of a target exit data
/// region.
void beginTargetDataExit(int64_t DeviceId, void *Code);

/// Top-level entry for invoking the tool callback after a target exit data
/// region.
void endTargetDataExit(int64_t DeviceId, void *Code);
/// Top-level entry for invoking the tool callback ahead of a target update
/// construct.
void beginTargetUpdate(int64_t DeviceId, void *Code);

/// Top-level entry for invoking the tool callback after a target update
/// construct.
void endTargetUpdate(int64_t DeviceId, void *Code);
/// Top-level entry for invoking the tool callback ahead of a target
/// construct.
void beginTarget(int64_t DeviceId, void *Code);

/// Top-level entry for invoking the tool callback after a target construct.
void endTarget(int64_t DeviceId, void *Code);
115 // Callback getter: Target data operations
116 template <ompt_target_data_op_t OpType
> auto getCallbacks() {
117 if constexpr (OpType
== ompt_target_data_alloc
||
118 OpType
== ompt_target_data_alloc_async
)
119 return std::make_pair(std::mem_fn(&Interface::beginTargetDataAlloc
),
120 std::mem_fn(&Interface::endTargetDataAlloc
));
122 if constexpr (OpType
== ompt_target_data_delete
||
123 OpType
== ompt_target_data_delete_async
)
124 return std::make_pair(std::mem_fn(&Interface::beginTargetDataDelete
),
125 std::mem_fn(&Interface::endTargetDataDelete
));
127 if constexpr (OpType
== ompt_target_data_transfer_to_device
||
128 OpType
== ompt_target_data_transfer_to_device_async
)
129 return std::make_pair(std::mem_fn(&Interface::beginTargetDataSubmit
),
130 std::mem_fn(&Interface::endTargetDataSubmit
));
132 if constexpr (OpType
== ompt_target_data_transfer_from_device
||
133 OpType
== ompt_target_data_transfer_from_device_async
)
134 return std::make_pair(std::mem_fn(&Interface::beginTargetDataRetrieve
),
135 std::mem_fn(&Interface::endTargetDataRetrieve
));
137 llvm_unreachable("Unhandled target data operation type!");
140 // Callback getter: Target region operations
141 template <ompt_target_t OpType
> auto getCallbacks() {
142 if constexpr (OpType
== ompt_target_enter_data
||
143 OpType
== ompt_target_enter_data_nowait
)
144 return std::make_pair(std::mem_fn(&Interface::beginTargetDataEnter
),
145 std::mem_fn(&Interface::endTargetDataEnter
));
147 if constexpr (OpType
== ompt_target_exit_data
||
148 OpType
== ompt_target_exit_data_nowait
)
149 return std::make_pair(std::mem_fn(&Interface::beginTargetDataExit
),
150 std::mem_fn(&Interface::endTargetDataExit
));
152 if constexpr (OpType
== ompt_target_update
||
153 OpType
== ompt_target_update_nowait
)
154 return std::make_pair(std::mem_fn(&Interface::beginTargetUpdate
),
155 std::mem_fn(&Interface::endTargetUpdate
));
157 if constexpr (OpType
== ompt_target
|| OpType
== ompt_target_nowait
)
158 return std::make_pair(std::mem_fn(&Interface::beginTarget
),
159 std::mem_fn(&Interface::endTarget
));
161 llvm_unreachable("Unknown target region operation type!");
164 // Callback getter: Kernel launch operation
165 template <ompt_callbacks_t OpType
> auto getCallbacks() {
166 // We use 'ompt_callbacks_t', because no other enum is currently available
167 // to model a kernel launch / target submit operation.
168 if constexpr (OpType
== ompt_callback_target_submit
)
169 return std::make_pair(std::mem_fn(&Interface::beginTargetSubmit
),
170 std::mem_fn(&Interface::endTargetSubmit
));
172 llvm_unreachable("Unhandled target operation!");
175 /// Setters for target region and target operation correlation ids
176 void setTargetDataValue(uint64_t DataValue
) { TargetData
.value
= DataValue
; }
177 void setTargetDataPtr(void *DataPtr
) { TargetData
.ptr
= DataPtr
; }
178 void setHostOpId(ompt_id_t OpId
) { HostOpId
= OpId
; }
180 /// Getters for target region and target operation correlation ids
181 uint64_t getTargetDataValue() { return TargetData
.value
; }
182 void *getTargetDataPtr() { return TargetData
.ptr
; }
183 ompt_id_t
getHostOpId() { return HostOpId
; }
186 /// Target operations id
187 ompt_id_t HostOpId
= 0;
189 /// Target region data
190 ompt_data_t TargetData
= ompt_data_none
;
192 /// Task data representing the encountering task
193 ompt_data_t
*TaskData
= nullptr;
195 /// Target task data representing the target task region
196 ompt_data_t
*TargetTaskData
= nullptr;
198 /// Correlation id that is incremented with target operations
199 uint64_t TargetRegionOpId
= 1;
/// Marks the beginning of a data operation.
void beginTargetDataOperation();

/// Marks the end of a data operation.
void endTargetDataOperation();

/// Marks the beginning of a target region.
void beginTargetRegion();

/// Marks the end of a target region.
void endTargetRegion();
214 /// Thread local state for target region and associated metadata
215 extern thread_local Interface RegionInterface
;
217 template <typename FuncTy
, typename ArgsTy
, size_t... IndexSeq
>
218 void InvokeInterfaceFunction(FuncTy Func
, ArgsTy Args
,
219 std::index_sequence
<IndexSeq
...>) {
220 std::invoke(Func
, RegionInterface
, std::get
<IndexSeq
>(Args
)...);
// Scoped helper around a begin/end callback pair: the constructor runs
// begin() and the destructor runs end(), each wrapped in
// performIfOmptInitialized (defined outside this excerpt; presumably it
// guards on OMPT having been initialized -- confirm).
// NOTE(review): no copy/move control is visible in this excerpt; if instances
// can be copied, the end callback would presumably run once per copy --
// confirm against the full class.
223 template <typename CallbackPairTy
, typename
... ArgsTy
> class InterfaceRAII
{
// Stores a copy of Args in the Arguments tuple, takes the begin callable from
// element 0 and the end callable from element 1 of Callbacks, then issues the
// begin call.
225 InterfaceRAII(CallbackPairTy Callbacks
, ArgsTy
... Args
)
226 : Arguments(Args
...), beginFunction(std::get
<0>(Callbacks
)),
227 endFunction(std::get
<1>(Callbacks
)) {
228 performIfOmptInitialized(begin());
// Issues the end call when the object is destroyed.
230 ~InterfaceRAII() { performIfOmptInitialized(end()); }
// Index sequence spanning every element of Arguments; used to forward the
// stored arguments to beginFunction through InvokeInterfaceFunction.
235 std::make_index_sequence
<std::tuple_size_v
<decltype(Arguments
)>>{};
236 InvokeInterfaceFunction(beginFunction
, Arguments
, IndexSequence
);
// Same forwarding as above, this time targeting endFunction.
241 std::make_index_sequence
<std::tuple_size_v
<decltype(Arguments
)>>{};
242 InvokeInterfaceFunction(endFunction
, Arguments
, IndexSequence
);
// Arguments captured at construction; passed to both the begin and the end
// call.
245 std::tuple
<ArgsTy
...> Arguments
;
// Callable stored from the first element of the callback pair.
246 typename
CallbackPairTy::first_type beginFunction
;
// Callable stored from the second element of the callback pair.
247 typename
CallbackPairTy::second_type endFunction
;
250 // InterfaceRAII's class template argument deduction guide
251 template <typename CallbackPairTy
, typename
... ArgsTy
>
252 InterfaceRAII(CallbackPairTy Callbacks
, ArgsTy
... Args
)
253 -> InterfaceRAII
<CallbackPairTy
, ArgsTy
...>;
256 } // namespace target
// Fallback definition: expands to nothing, so the wrapped statement is
// dropped. NOTE(review): presumably the branch used when OMPT support is not
// built; the surrounding conditional is not visible here -- confirm.
#define OMPT_IF_BUILT(stmt)
263 #endif // _OMPTARGET_OMPTINTERFACE_H