Run DCE after a LoopFlatten test to reduce spurious output [nfc]
[llvm-project.git] / openmp / libomptarget / src / OmptInterface.h
blob178cedacf4a5860bb5374a0c53f9241308de6ba0
1 //===-------- OmptInterface.h - Target independent OpenMP target RTL ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Declarations for OpenMP Tool callback dispatchers
11 //===----------------------------------------------------------------------===//
13 #ifndef _OMPTARGET_OMPTINTERFACE_H
14 #define _OMPTARGET_OMPTINTERFACE_H
16 // Only provide functionality if target OMPT support is enabled
17 #ifdef OMPT_SUPPORT
18 #include <functional>
19 #include <tuple>
21 #include "OmptCallback.h"
22 #include "omp-tools.h"
24 #include "llvm/Support/ErrorHandling.h"
26 #define OMPT_IF_BUILT(stmt) stmt // Emits stmt unchanged when OMPT_SUPPORT is defined; the #else branch of this header defines it as empty.
27 #define OMPT_GET_RETURN_ADDRESS(level) __builtin_return_address(level) // Forwards 'level' to the compiler builtin of the same name.
29 /// Callbacks for target regions require task_data representing the
30 /// encountering task.
31 /// Callbacks for target regions and target data ops require
32 /// target_task_data representing the target task region.
33 typedef ompt_data_t *(*ompt_get_task_data_t)();
34 typedef ompt_data_t *(*ompt_get_target_task_data_t)();
36 namespace llvm {
37 namespace omp {
38 namespace target {
39 namespace ompt {
41 /// Function pointers that will be used to track task_data and
42 /// target_task_data. NOTE(review): these are 'static' at namespace scope in a header, so every including translation unit gets its own copy; confirm that is intended.
43 static ompt_get_task_data_t ompt_get_task_data_fn; // Not assigned anywhere in this header.
44 static ompt_get_target_task_data_t ompt_get_target_task_data_fn; // Not assigned anywhere in this header.
46 /// Used to maintain execution state for this thread
47 class Interface {
48 public:
49 /// Top-level function for invoking callback before device data allocation
50 void beginTargetDataAlloc(int64_t DeviceId, void *HstPtrBegin,
51 void **TgtPtrBegin, size_t Size, void *Code);
53 /// Top-level function for invoking callback after device data allocation
54 void endTargetDataAlloc(int64_t DeviceId, void *HstPtrBegin,
55 void **TgtPtrBegin, size_t Size, void *Code);
57 /// Top-level function for invoking callback before data submit
58 void beginTargetDataSubmit(int64_t DeviceId, void *HstPtrBegin,
59 void *TgtPtrBegin, size_t Size, void *Code);
61 /// Top-level function for invoking callback after data submit
62 void endTargetDataSubmit(int64_t DeviceId, void *HstPtrBegin,
63 void *TgtPtrBegin, size_t Size, void *Code);
65 /// Top-level function for invoking callback before device data deallocation
66 void beginTargetDataDelete(int64_t DeviceId, void *TgtPtrBegin, void *Code);
68 /// Top-level function for invoking callback after device data deallocation
69 void endTargetDataDelete(int64_t DeviceId, void *TgtPtrBegin, void *Code);
71 /// Top-level function for invoking callback before data retrieve
72 void beginTargetDataRetrieve(int64_t DeviceId, void *HstPtrBegin,
73 void *TgtPtrBegin, size_t Size, void *Code);
75 /// Top-level function for invoking callback after data retrieve
76 void endTargetDataRetrieve(int64_t DeviceId, void *HstPtrBegin,
77 void *TgtPtrBegin, size_t Size, void *Code);
79 /// Top-level function for invoking callback before kernel dispatch
80 void beginTargetSubmit(unsigned int NumTeams = 1);
82 /// Top-level function for invoking callback after kernel dispatch
83 void endTargetSubmit(unsigned int NumTeams = 1);
85 // Target region callbacks
87 /// Top-level function for invoking callback before target enter data
88 /// construct
89 void beginTargetDataEnter(int64_t DeviceId, void *Code);
91 /// Top-level function for invoking callback after target enter data
92 /// construct
93 void endTargetDataEnter(int64_t DeviceId, void *Code);
95 /// Top-level function for invoking callback before target exit data
96 /// construct
97 void beginTargetDataExit(int64_t DeviceId, void *Code);
99 /// Top-level function for invoking callback after target exit data
100 /// construct
101 void endTargetDataExit(int64_t DeviceId, void *Code);
103 /// Top-level function for invoking callback before target update construct
104 void beginTargetUpdate(int64_t DeviceId, void *Code);
106 /// Top-level function for invoking callback after target update construct
107 void endTargetUpdate(int64_t DeviceId, void *Code);
109 /// Top-level function for invoking callback before target construct
110 void beginTarget(int64_t DeviceId, void *Code);
112 /// Top-level function for invoking callback after target construct
113 void endTarget(int64_t DeviceId, void *Code);
115 // Callback getter: Target data operations
116 template <ompt_target_data_op_t OpType> auto getCallbacks() {
117 if constexpr (OpType == ompt_target_data_alloc ||
118 OpType == ompt_target_data_alloc_async)
119 return std::make_pair(std::mem_fn(&Interface::beginTargetDataAlloc),
120 std::mem_fn(&Interface::endTargetDataAlloc));
122 if constexpr (OpType == ompt_target_data_delete ||
123 OpType == ompt_target_data_delete_async)
124 return std::make_pair(std::mem_fn(&Interface::beginTargetDataDelete),
125 std::mem_fn(&Interface::endTargetDataDelete));
127 if constexpr (OpType == ompt_target_data_transfer_to_device ||
128 OpType == ompt_target_data_transfer_to_device_async)
129 return std::make_pair(std::mem_fn(&Interface::beginTargetDataSubmit),
130 std::mem_fn(&Interface::endTargetDataSubmit));
132 if constexpr (OpType == ompt_target_data_transfer_from_device ||
133 OpType == ompt_target_data_transfer_from_device_async)
134 return std::make_pair(std::mem_fn(&Interface::beginTargetDataRetrieve),
135 std::mem_fn(&Interface::endTargetDataRetrieve));
137 llvm_unreachable("Unhandled target data operation type!");
140 // Callback getter: Target region operations
141 template <ompt_target_t OpType> auto getCallbacks() {
142 if constexpr (OpType == ompt_target_enter_data ||
143 OpType == ompt_target_enter_data_nowait)
144 return std::make_pair(std::mem_fn(&Interface::beginTargetDataEnter),
145 std::mem_fn(&Interface::endTargetDataEnter));
147 if constexpr (OpType == ompt_target_exit_data ||
148 OpType == ompt_target_exit_data_nowait)
149 return std::make_pair(std::mem_fn(&Interface::beginTargetDataExit),
150 std::mem_fn(&Interface::endTargetDataExit));
152 if constexpr (OpType == ompt_target_update ||
153 OpType == ompt_target_update_nowait)
154 return std::make_pair(std::mem_fn(&Interface::beginTargetUpdate),
155 std::mem_fn(&Interface::endTargetUpdate));
157 if constexpr (OpType == ompt_target || OpType == ompt_target_nowait)
158 return std::make_pair(std::mem_fn(&Interface::beginTarget),
159 std::mem_fn(&Interface::endTarget));
161 llvm_unreachable("Unknown target region operation type!");
164 // Callback getter: Kernel launch operation
165 template <ompt_callbacks_t OpType> auto getCallbacks() {
166 // We use 'ompt_callbacks_t', because no other enum is currently available
167 // to model a kernel launch / target submit operation.
168 if constexpr (OpType == ompt_callback_target_submit)
169 return std::make_pair(std::mem_fn(&Interface::beginTargetSubmit),
170 std::mem_fn(&Interface::endTargetSubmit));
172 llvm_unreachable("Unhandled target operation!");
175 /// Setters for target region and target operation correlation ids
176 void setTargetDataValue(uint64_t DataValue) { TargetData.value = DataValue; }
177 void setTargetDataPtr(void *DataPtr) { TargetData.ptr = DataPtr; }
178 void setHostOpId(ompt_id_t OpId) { HostOpId = OpId; }
180 /// Getters for target region and target operation correlation ids
181 uint64_t getTargetDataValue() { return TargetData.value; }
182 void *getTargetDataPtr() { return TargetData.ptr; }
183 ompt_id_t getHostOpId() { return HostOpId; }
185 private:
186 /// Target operations id
187 ompt_id_t HostOpId = 0;
189 /// Target region data
190 ompt_data_t TargetData = ompt_data_none;
192 /// Task data representing the encountering task
193 ompt_data_t *TaskData = nullptr;
195 /// Target task data representing the target task region
196 ompt_data_t *TargetTaskData = nullptr;
198 /// Correlation id that is incremented with target operations
199 uint64_t TargetRegionOpId = 1;
201 /// Used for marking begin of a data operation
202 void beginTargetDataOperation();
204 /// Used for marking end of a data operation
205 void endTargetDataOperation();
207 /// Used for marking begin of a target region
208 void beginTargetRegion();
210 /// Used for marking end of a target region
211 void endTargetRegion();
214 /// Thread local state for target region and associated metadata.
215 extern thread_local Interface RegionInterface; // Declaration only ('extern'); no definition appears in this header.
217 template <typename FuncTy, typename ArgsTy, size_t... IndexSeq>
218 void InvokeInterfaceFunction(FuncTy Func, ArgsTy Args,
219 std::index_sequence<IndexSeq...>) {
220 std::invoke(Func, RegionInterface, std::get<IndexSeq>(Args)...);
/// Scope guard that brackets a region of code with a begin/end pair of
/// Interface member functions.
///
/// The constructor stores a copy of the arguments and both callables, then
/// calls the 'begin' callable; the destructor calls the 'end' callable with
/// the same stored arguments. Both calls go through InvokeInterfaceFunction
/// and are wrapped in performIfOmptInitialized, which is declared outside
/// this header (presumably it skips the call when OMPT is not initialized;
/// confirm against its definition).
///
/// \tparam CallbackPairTy pair-like type whose element 0 is the begin
///         callable and element 1 the end callable, e.g. the result of
///         Interface::getCallbacks.
/// \tparam ArgsTy types of the arguments handed to both callables.
template <typename CallbackPairTy, typename... ArgsTy> class InterfaceRAII {
public:
  InterfaceRAII(CallbackPairTy Callbacks, ArgsTy... Args)
      : Arguments(Args...), beginFunction(std::get<0>(Callbacks)),
        endFunction(std::get<1>(Callbacks)) {
    performIfOmptInitialized(begin());
  }

  ~InterfaceRAII() { performIfOmptInitialized(end()); }

  // Each object stands for exactly one begin/end pairing. A copied or
  // moved-from object would still run its destructor and fire the 'end'
  // callable a second time for a single 'begin', so copying and moving are
  // disabled. Direct initialization and C++17 prvalue initialization are
  // unaffected.
  InterfaceRAII(const InterfaceRAII &) = delete;
  InterfaceRAII &operator=(const InterfaceRAII &) = delete;
  InterfaceRAII(InterfaceRAII &&) = delete;
  InterfaceRAII &operator=(InterfaceRAII &&) = delete;

private:
  /// Calls the stored 'begin' callable with the stored arguments.
  void begin() {
    InvokeInterfaceFunction(beginFunction, Arguments,
                            std::index_sequence_for<ArgsTy...>{});
  }

  /// Calls the stored 'end' callable with the stored arguments.
  void end() {
    InvokeInterfaceFunction(endFunction, Arguments,
                            std::index_sequence_for<ArgsTy...>{});
  }

  /// Copies of the constructor arguments, reused for the 'end' call.
  std::tuple<ArgsTy...> Arguments;
  typename CallbackPairTy::first_type beginFunction;
  typename CallbackPairTy::second_type endFunction;
};

// InterfaceRAII's class template argument deduction guide: lets callers write
// 'InterfaceRAII Guard(Callbacks, Args...)' without spelling the template
// arguments.
template <typename CallbackPairTy, typename... ArgsTy>
InterfaceRAII(CallbackPairTy Callbacks, ArgsTy... Args)
    -> InterfaceRAII<CallbackPairTy, ArgsTy...>;
255 } // namespace ompt
256 } // namespace target
257 } // namespace omp
258 } // namespace llvm
259 #else
260 #define OMPT_IF_BUILT(stmt)
261 #endif
263 #endif // _OMPTARGET_OMPTINTERFACE_H