//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpenMP specific optimizations:
//
// - Deduplication of runtime calls, e.g., omp_get_thread_num.
// - Replacing globalized device memory with stack memory.
// - Replacing globalized device memory with shared memory.
// - Parallel region merging.
// - Transforming generic-mode device kernels to SPMD mode.
// - Specializing the state machine for generic-mode device kernels.
//
//===----------------------------------------------------------------------===//
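//
// Illustrative sketch (not part of the original header): runtime call
// deduplication, the first optimization listed above, rewrites repeated calls
// such as
//
//   %tid0 = call i32 @omp_get_thread_num()
//   ...
//   %tid1 = call i32 @omp_get_thread_num()
//
// within one function into a single call whose result is reused, provided
// both calls must observe the same value.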
#include "llvm/Transforms/IPO/OpenMPOpt.h"

#include "llvm/ADT/EnumeratedArray.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/Assumptions.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/Attributor.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/CallGraphUpdater.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"

using namespace llvm;
using namespace omp;

#define DEBUG_TYPE "openmp-opt"
static cl::opt<bool> DisableOpenMPOptimizations(
    "openmp-opt-disable", cl::ZeroOrMore,
    cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
    cl::init(false));

static cl::opt<bool> EnableParallelRegionMerging(
    "openmp-opt-enable-merging", cl::ZeroOrMore,
    cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
    cl::init(false));

static cl::opt<bool>
    DisableInternalization("openmp-opt-disable-internalization", cl::ZeroOrMore,
                           cl::desc("Disable function internalization."),
                           cl::Hidden, cl::init(false));

static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
                                    cl::Hidden);

static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
                                        cl::init(false), cl::Hidden);

static cl::opt<bool> HideMemoryTransferLatency(
    "openmp-hide-memory-transfer-latency",
    cl::desc("[WIP] Tries to hide the latency of host to device memory"
             " transfers"),
    cl::Hidden, cl::init(false));

static cl::opt<bool> DisableOpenMPOptDeglobalization(
    "openmp-opt-disable-deglobalization", cl::ZeroOrMore,
    cl::desc("Disable OpenMP optimizations involving deglobalization."),
    cl::Hidden, cl::init(false));

static cl::opt<bool> DisableOpenMPOptSPMDization(
    "openmp-opt-disable-spmdization", cl::ZeroOrMore,
    cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
    cl::Hidden, cl::init(false));

static cl::opt<bool> DisableOpenMPOptFolding(
    "openmp-opt-disable-folding", cl::ZeroOrMore,
    cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
    cl::init(false));

static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
    "openmp-opt-disable-state-machine-rewrite", cl::ZeroOrMore,
    cl::desc("Disable OpenMP optimizations that replace the state machine."),
    cl::Hidden, cl::init(false));

static cl::opt<bool> PrintModuleAfterOptimizations(
    "openmp-opt-print-module", cl::ZeroOrMore,
    cl::desc("Print the current module after OpenMP optimizations."),
    cl::Hidden, cl::init(false));
STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
          "Number of OpenMP runtime calls deduplicated");
STATISTIC(NumOpenMPParallelRegionsDeleted,
          "Number of OpenMP parallel regions deleted");
STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
          "Number of OpenMP runtime functions identified");
STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
          "Number of OpenMP runtime function uses identified");
STATISTIC(NumOpenMPTargetRegionKernels,
          "Number of OpenMP target region entry points (=kernels) identified");
STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "SPMD-mode instead of generic-mode");
STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode without a state machine");
STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode with customized state machines with fallback");
STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode with customized state machines without fallback");
STATISTIC(
    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
    "Number of OpenMP parallel regions replaced with ID in GPU state machines");
STATISTIC(NumOpenMPParallelRegionsMerged,
          "Number of OpenMP parallel regions merged");
STATISTIC(NumBytesMovedToSharedMemory,
          "Amount of memory pushed to shared memory");
static constexpr auto TAG = "[" DEBUG_TYPE "]";

enum class AddressSpace : unsigned {
  Generic = 0,
  Global = 1,
  Shared = 3,
  Constant = 4,
  Local = 5,
};

struct AAHeapToShared;
/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
/// Attributor runs.
struct OMPInformationCache : public InformationCache {
  OMPInformationCache(Module &M, AnalysisGetter &AG,
                      BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
                      SmallPtrSetImpl<Kernel> &Kernels)
      : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
        Kernels(Kernels) {
    OMPBuilder.initialize();
    initializeRuntimeFunctions();
    initializeInternalControlVars();
  }
  /// Generic information that describes an internal control variable.
  struct InternalControlVarInfo {
    /// The kind, as described by the InternalControlVar enum.
    InternalControlVar Kind;

    /// The name of the ICV.
    StringRef Name;

    /// Environment variable associated with this ICV.
    StringRef EnvVarName;

    /// Initial value kind.
    ICVInitValue InitKind;

    /// Initial value.
    ConstantInt *InitValue;

    /// Setter RTL function associated with this ICV.
    RuntimeFunction Setter;

    /// Getter RTL function associated with this ICV.
    RuntimeFunction Getter;

    /// RTL function corresponding to the override clause of this ICV.
    RuntimeFunction Clause;
  };
  /// Generic information that describes a runtime function.
  struct RuntimeFunctionInfo {

    /// The kind, as described by the RuntimeFunction enum.
    RuntimeFunction Kind;

    /// The name of the function.
    StringRef Name;

    /// Flag to indicate a variadic function.
    bool IsVarArg;

    /// The return type of the function.
    Type *ReturnType;

    /// The argument types of the function.
    SmallVector<Type *, 8> ArgumentTypes;

    /// The declaration if available.
    Function *Declaration = nullptr;

    /// Uses of this runtime function per function containing the use.
    using UseVector = SmallVector<Use *, 16>;

    /// Clear UsesMap for this runtime function.
    void clearUsesMap() { UsesMap.clear(); }

    /// Boolean conversion that is true if the runtime function was found.
    operator bool() const { return Declaration; }
    /// Return the vector of uses in function \p F.
    UseVector &getOrCreateUseVector(Function *F) {
      std::shared_ptr<UseVector> &UV = UsesMap[F];
      if (!UV)
        UV = std::make_shared<UseVector>();
      return *UV;
    }

    /// Return the vector of uses in function \p F or `nullptr` if there are
    /// none.
    const UseVector *getUseVector(Function &F) const {
      auto I = UsesMap.find(&F);
      if (I != UsesMap.end())
        return I->second.get();
      return nullptr;
    }

    /// Return how many functions contain uses of this runtime function.
    size_t getNumFunctionsWithUses() const { return UsesMap.size(); }

    /// Return the number of arguments (or the minimal number for variadic
    /// functions).
    size_t getNumArgs() const { return ArgumentTypes.size(); }
    /// Run the callback \p CB on each use and forget the use if the result is
    /// true. The callback will be fed the function in which the use was
    /// encountered as second argument.
    void foreachUse(SmallVectorImpl<Function *> &SCC,
                    function_ref<bool(Use &, Function &)> CB) {
      for (Function *F : SCC)
        foreachUse(CB, F);
    }

    /// Run the callback \p CB on each use within the function \p F and forget
    /// the use if the result is true.
    void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
      SmallVector<unsigned, 8> ToBeDeleted;

      unsigned Idx = 0;
      UseVector &UV = getOrCreateUseVector(F);
      for (Use *U : UV) {
        if (CB(*U, *F))
          ToBeDeleted.push_back(Idx);
        ++Idx;
      }

      // Remove the to-be-deleted indices in reverse order as prior
      // modifications will not modify the smaller indices.
      while (!ToBeDeleted.empty()) {
        unsigned Idx = ToBeDeleted.pop_back_val();
        UV[Idx] = UV.back();
        UV.pop_back();
      }
    }

    /// Map from functions to all uses of this runtime function contained in
    /// them.
    DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;

    /// Iterators for the uses of this runtime function.
    decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
    decltype(UsesMap)::iterator end() { return UsesMap.end(); }
  };
  /// An OpenMP-IR-Builder instance.
  OpenMPIRBuilder OMPBuilder;

  /// Map from runtime function kind to the runtime function description.
  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
                  RuntimeFunction::OMPRTL___last>
      RFIs;

  /// Map from function declarations/definitions to their runtime enum type.
  DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;

  /// Map from ICV kind to the ICV description.
  EnumeratedArray<InternalControlVarInfo, InternalControlVar,
                  InternalControlVar::ICV___last>
      ICVs;
  /// Helper to initialize all internal control variable information for those
  /// defined in OMPKinds.def.
  void initializeInternalControlVars() {
#define ICV_RT_SET(_Name, RTL)                                                 \
  {                                                                            \
    auto &ICV = ICVs[_Name];                                                   \
    ICV.Setter = RTL;                                                          \
  }
#define ICV_RT_GET(Name, RTL)                                                  \
  {                                                                            \
    auto &ICV = ICVs[Name];                                                    \
    ICV.Getter = RTL;                                                          \
  }
#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
  {                                                                            \
    auto &ICV = ICVs[Enum];                                                    \
    ICV.Name = _Name;                                                          \
    ICV.Kind = Enum;                                                           \
    ICV.InitKind = Init;                                                       \
    ICV.EnvVarName = _EnvVarName;                                              \
    switch (ICV.InitKind) {                                                    \
    case ICV_IMPLEMENTATION_DEFINED:                                           \
      ICV.InitValue = nullptr;                                                 \
      break;                                                                   \
    case ICV_ZERO:                                                             \
      ICV.InitValue = ConstantInt::get(                                        \
          Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
      break;                                                                   \
    case ICV_FALSE:                                                            \
      ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
      break;                                                                   \
    case ICV_LAST:                                                             \
      break;                                                                   \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }
  /// Returns true if the function declaration \p F matches the runtime
  /// function types, that is, return type \p RTFRetType, and argument types
  /// \p RTFArgTypes.
  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
                                  SmallVector<Type *, 8> &RTFArgTypes) {
    // TODO: We should output information to the user (under debug output
    //       and via remarks).

    if (!F)
      return false;
    if (F->getReturnType() != RTFRetType)
      return false;
    if (F->arg_size() != RTFArgTypes.size())
      return false;

    auto RTFTyIt = RTFArgTypes.begin();
    for (Argument &Arg : F->args()) {
      if (Arg.getType() != *RTFTyIt)
        return false;

      ++RTFTyIt;
    }

    return true;
  }
  // Helper to collect all uses of the declaration in the UsesMap.
  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
    unsigned NumUses = 0;
    if (!RFI.Declaration)
      return NumUses;
    OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);

    if (CollectStats) {
      NumOpenMPRuntimeFunctionsIdentified += 1;
      NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
    }

    // TODO: We directly convert uses into proper calls and unknown uses.
    for (Use &U : RFI.Declaration->uses()) {
      if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
        if (ModuleSlice.count(UserI->getFunction())) {
          RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
          ++NumUses;
        }
      } else {
        RFI.getOrCreateUseVector(nullptr).push_back(&U);
        ++NumUses;
      }
    }
    return NumUses;
  }
  // Helper function to recollect uses of a runtime function.
  void recollectUsesForFunction(RuntimeFunction RTF) {
    auto &RFI = RFIs[RTF];
    RFI.clearUsesMap();
    collectUses(RFI, /*CollectStats*/ false);
  }

  // Helper function to recollect uses of all runtime functions.
  void recollectUses() {
    for (int Idx = 0; Idx < RFIs.size(); ++Idx)
      recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
  }
  /// Helper to initialize all runtime function information for those defined
  /// in OpenMPKinds.def.
  void initializeRuntimeFunctions() {
    Module &M = *((*ModuleSlice.begin())->getParent());

    // Helper macros for handling __VA_ARGS__ in OMP_RTL
#define OMP_TYPE(VarName, ...)                                                 \
  Type *VarName = OMPBuilder.VarName;                                          \
  (void)VarName;

#define OMP_ARRAY_TYPE(VarName, ...)                                           \
  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
  (void)VarName##Ty;                                                           \
  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
  (void)VarName##PtrTy;

#define OMP_FUNCTION_TYPE(VarName, ...)                                        \
  FunctionType *VarName = OMPBuilder.VarName;                                  \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

#define OMP_STRUCT_TYPE(VarName, ...)                                          \
  StructType *VarName = OMPBuilder.VarName;                                    \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
  {                                                                            \
    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
    Function *F = M.getFunction(_Name);                                        \
    RTLFunctions.insert(F);                                                    \
    if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
      RuntimeFunctionIDMap[F] = _Enum;                                         \
      F->removeFnAttr(Attribute::NoInline);                                    \
      auto &RFI = RFIs[_Enum];                                                 \
      RFI.Kind = _Enum;                                                        \
      RFI.Name = _Name;                                                        \
      RFI.IsVarArg = _IsVarArg;                                                \
      RFI.ReturnType = OMPBuilder._ReturnType;                                 \
      RFI.ArgumentTypes = std::move(ArgsTypes);                                \
      RFI.Declaration = F;                                                     \
      unsigned NumUses = collectUses(RFI);                                     \
      (void)NumUses;                                                           \
      LLVM_DEBUG({                                                             \
        dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
               << " found\n";                                                  \
        if (RFI.Declaration)                                                   \
          dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
                 << RFI.getNumFunctionsWithUses()                              \
                 << " different functions.\n";                                 \
      });                                                                      \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"

    // TODO: We should attach the attributes defined in OMPKinds.def.
  }
  /// Collection of known kernels (\see Kernel) in the module.
  SmallPtrSetImpl<Kernel> &Kernels;

  /// Collection of known OpenMP runtime functions.
  DenseSet<const Function *> RTLFunctions;
};
template <typename Ty, bool InsertInvalidates = true>
struct BooleanStateWithSetVector : public BooleanState {
  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
  bool insert(const Ty &Elem) {
    if (InsertInvalidates)
      BooleanState::indicatePessimisticFixpoint();
    return Set.insert(Elem);
  }

  const Ty &operator[](int Idx) const { return Set[Idx]; }
  bool operator==(const BooleanStateWithSetVector &RHS) const {
    return BooleanState::operator==(RHS) && Set == RHS.Set;
  }
  bool operator!=(const BooleanStateWithSetVector &RHS) const {
    return !(*this == RHS);
  }

  bool empty() const { return Set.empty(); }
  size_t size() const { return Set.size(); }

  /// "Clamp" this state with \p RHS.
  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
    BooleanState::operator^=(RHS);
    Set.insert(RHS.Set.begin(), RHS.Set.end());
    return *this;
  }

  /// A set to keep track of elements.
  SetVector<Ty> Set;

  typename decltype(Set)::iterator begin() { return Set.begin(); }
  typename decltype(Set)::iterator end() { return Set.end(); }
  typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
  typename decltype(Set)::const_iterator end() const { return Set.end(); }
};

template <typename Ty, bool InsertInvalidates = true>
using BooleanStateWithPtrSetVector =
    BooleanStateWithSetVector<Ty *, InsertInvalidates>;
struct KernelInfoState : AbstractState {
  /// Flag to track if we reached a fixpoint.
  bool IsAtFixpoint = false;

  /// The parallel regions (identified by the outlined parallel functions) that
  /// can be reached from the associated function.
  BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
      ReachedKnownParallelRegions;

  /// State to track what parallel region we might reach.
  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;

  /// State to track if we are in SPMD-mode, assumed or known, and why we
  /// decided we cannot be. If it is assumed, then RequiresFullRuntime should
  /// also be assumed accordingly.
  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;

  /// The __kmpc_target_init call in this kernel, if any. If we find more than
  /// one we abort as the kernel is malformed.
  CallBase *KernelInitCB = nullptr;

  /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
  /// one we abort as the kernel is malformed.
  CallBase *KernelDeinitCB = nullptr;

  /// Flag to indicate if the associated function is a kernel entry.
  bool IsKernelEntry = false;

  /// State to track what kernel entries can reach the associated function.
  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;

  /// State to indicate if we can track the parallel level of the associated
  /// function. We will give up tracking if we encounter an unknown caller or
  /// the caller is __kmpc_parallel_51.
  BooleanStateWithSetVector<uint8_t> ParallelLevels;

  /// Abstract State interface

  KernelInfoState() = default;
  KernelInfoState(bool BestState) {
    if (!BestState)
      indicatePessimisticFixpoint();
  }

  /// See AbstractState::isValidState(...)
  bool isValidState() const override { return true; }

  /// See AbstractState::isAtFixpoint(...)
  bool isAtFixpoint() const override { return IsAtFixpoint; }

  /// See AbstractState::indicatePessimisticFixpoint(...)
  ChangeStatus indicatePessimisticFixpoint() override {
    IsAtFixpoint = true;
    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
    return ChangeStatus::CHANGED;
  }

  /// See AbstractState::indicateOptimisticFixpoint(...)
  ChangeStatus indicateOptimisticFixpoint() override {
    IsAtFixpoint = true;
    return ChangeStatus::UNCHANGED;
  }

  /// Return the assumed state.
  KernelInfoState &getAssumed() { return *this; }
  const KernelInfoState &getAssumed() const { return *this; }

  bool operator==(const KernelInfoState &RHS) const {
    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
      return false;
    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
      return false;
    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
      return false;
    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
      return false;
    return true;
  }

  /// Return empty set as the best state of potential values.
  static KernelInfoState getBestState() { return KernelInfoState(true); }

  static KernelInfoState getBestState(KernelInfoState &KIS) {
    return getBestState();
  }

  /// Return full set as the worst state of potential values.
  static KernelInfoState getWorstState() { return KernelInfoState(false); }

  /// "Clamp" this state with \p KIS.
  KernelInfoState operator^=(const KernelInfoState &KIS) {
    // Do not merge two different _init and _deinit call sites.
    if (KIS.KernelInitCB) {
      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
        indicatePessimisticFixpoint();
      KernelInitCB = KIS.KernelInitCB;
    }
    if (KIS.KernelDeinitCB) {
      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
        indicatePessimisticFixpoint();
      KernelDeinitCB = KIS.KernelDeinitCB;
    }
    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
    return *this;
  }

  KernelInfoState operator&=(const KernelInfoState &KIS) {
    return (*this ^= KIS);
  }
};
/// Used to map the values physically (in the IR) stored in an offload
/// array, to a vector in memory.
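///
/// For illustration, the kind of IR this models looks roughly like (value
/// names are made up):
///
///   %offload_baseptrs = alloca [2 x i8*]
///   %gep0 = getelementptr inbounds [2 x i8*], [2 x i8*]* %offload_baseptrs, i64 0, i64 0
///   store i8* %host_ptr_a, i8** %gep0
///   %gep1 = getelementptr inbounds [2 x i8*], [2 x i8*]* %offload_baseptrs, i64 0, i64 1
///   store i8* %host_ptr_b, i8** %gep1
///
/// After initialize(), StoredValues holds {%host_ptr_a, %host_ptr_b} and
/// LastAccesses holds the two stores.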
struct OffloadArray {
  /// Physical array (in the IR).
  AllocaInst *Array = nullptr;
  /// Mapped values.
  SmallVector<Value *, 8> StoredValues;
  /// Last stores made in the offload array.
  SmallVector<StoreInst *, 8> LastAccesses;

  OffloadArray() = default;

  /// Initializes the OffloadArray with the values stored in \p Array before
  /// instruction \p Before is reached. Returns false if the initialization
  /// fails.
  /// This MUST be used immediately after the construction of the object.
  bool initialize(AllocaInst &Array, Instruction &Before) {
    if (!Array.getAllocatedType()->isArrayTy())
      return false;

    if (!getValues(Array, Before))
      return false;

    this->Array = &Array;
    return true;
  }

  static const unsigned DeviceIDArgNum = 1;
  static const unsigned BasePtrsArgNum = 3;
  static const unsigned PtrsArgNum = 4;
  static const unsigned SizesArgNum = 5;

  /// Traverses the BasicBlock where \p Array is, collecting the stores made to
  /// \p Array, leaving StoredValues with the values stored before the
  /// instruction \p Before is reached.
  bool getValues(AllocaInst &Array, Instruction &Before) {
    // Initialize container.
    const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
    StoredValues.assign(NumValues, nullptr);
    LastAccesses.assign(NumValues, nullptr);

    // TODO: This assumes the instruction \p Before is in the same
    //  BasicBlock as Array. Make it general, for any control flow graph.
    BasicBlock *BB = Array.getParent();
    if (BB != Before.getParent())
      return false;

    const DataLayout &DL = Array.getModule()->getDataLayout();
    const unsigned int PointerSize = DL.getPointerSize();

    for (Instruction &I : *BB) {
      if (&I == &Before)
        break;

      if (!isa<StoreInst>(&I))
        continue;

      auto *S = cast<StoreInst>(&I);
      int64_t Offset = -1;
      auto *Dst =
          GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
      if (Dst == &Array) {
        int64_t Idx = Offset / PointerSize;
        StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
        LastAccesses[Idx] = S;
      }
    }

    return isFilled();
  }

  /// Returns true if all values in StoredValues and
  /// LastAccesses are not nullptrs.
  bool isFilled() {
    const unsigned NumValues = StoredValues.size();
    for (unsigned I = 0; I < NumValues; ++I) {
      if (!StoredValues[I] || !LastAccesses[I])
        return false;
    }

    return true;
  }
};
struct OpenMPOpt {

  using OptimizationRemarkGetter =
      function_ref<OptimizationRemarkEmitter &(Function *)>;

  OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
            OptimizationRemarkGetter OREGetter,
            OMPInformationCache &OMPInfoCache, Attributor &A)
      : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
        OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}

  /// Check if any remarks are enabled for openmp-opt.
  bool remarksEnabled() {
    auto &Ctx = M.getContext();
    return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
  }

  /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
  bool run(bool IsModulePass) {
    if (SCC.empty())
      return false;

    bool Changed = false;

    LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
                      << " functions in a slice with "
                      << OMPInfoCache.ModuleSlice.size() << " functions\n");

    if (IsModulePass) {
      Changed |= runAttributor(IsModulePass);

      // Recollect uses, in case Attributor deleted any.
      OMPInfoCache.recollectUses();

      // TODO: This should be folded into buildCustomStateMachine.
      Changed |= rewriteDeviceCodeStateMachine();

      if (remarksEnabled())
        analysisGlobalization();
    } else {
      if (PrintICVValues)
        printICVs();
      if (PrintOpenMPKernels)
        printKernels();

      Changed |= runAttributor(IsModulePass);

      // Recollect uses, in case Attributor deleted any.
      OMPInfoCache.recollectUses();

      Changed |= deleteParallelRegions();

      if (HideMemoryTransferLatency)
        Changed |= hideMemTransfersLatency();
      Changed |= deduplicateRuntimeCalls();
      if (EnableParallelRegionMerging) {
        if (mergeParallelRegions()) {
          deduplicateRuntimeCalls();
          Changed = true;
        }
      }
    }

    return Changed;
  }
  /// Print initial ICV values for testing.
  /// FIXME: This should be done from the Attributor once it is added.
  void printICVs() const {
    InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
                                 ICV_proc_bind};

    for (Function *F : OMPInfoCache.ModuleSlice) {
      for (auto ICV : ICVs) {
        auto ICVInfo = OMPInfoCache.ICVs[ICV];
        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
          return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
                     << " Value: "
                     << (ICVInfo.InitValue
                             ? toString(ICVInfo.InitValue->getValue(), 10, true)
                             : "IMPLEMENTATION_DEFINED");
        };

        emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
      }
    }
  }

  /// Print OpenMP GPU kernels for testing.
  void printKernels() const {
    for (Function *F : SCC) {
      if (!OMPInfoCache.Kernels.count(F))
        continue;

      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
        return ORA << "OpenMP GPU kernel "
                   << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
      };

      emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
    }
  }
  /// Return the call if \p U is a callee use in a regular call. If \p RFI is
  /// given it has to be the callee or a nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(U.getUser());
    if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
        (!RFI ||
         (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
      return CI;
    return nullptr;
  }

  /// Return the call if \p V is a regular call. If \p RFI is given it has to be
  /// the callee or a nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(&V);
    if (CI && !CI->hasOperandBundles() &&
        (!RFI ||
         (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
      return CI;
    return nullptr;
  }
  /// Merge parallel regions when it is safe.
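  ///
  /// Rough source-level sketch of the transformation (the pass itself works
  /// on the __kmpc_fork_call sites in the IR):
  ///
  ///   #pragma omp parallel
  ///   { work1(); }
  ///   sequential();
  ///   #pragma omp parallel
  ///   { work2(); }
  ///
  /// becomes a single parallel region in which `sequential()` is wrapped in a
  /// master construct followed by a barrier, assuming the in-between code may
  /// legally be guarded this way.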
  bool mergeParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;
    const unsigned CallbackFirstArgOperand = 3;
    using InsertPointTy = OpenMPIRBuilder::InsertPointTy;

    // Check if there are any __kmpc_fork_call calls to merge.
    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    // Unmergable calls that prevent merging a parallel region.
    OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
    };

    bool Changed = false;
    LoopInfo *LI = nullptr;
    DominatorTree *DT = nullptr;

    SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;

    BasicBlock *StartBB = nullptr, *EndBB = nullptr;
    auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
                         BasicBlock &ContinuationIP) {
      BasicBlock *CGStartBB = CodeGenIP.getBlock();
      BasicBlock *CGEndBB =
          SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
      assert(StartBB != nullptr && "StartBB should not be null");
      CGStartBB->getTerminator()->setSuccessor(0, StartBB);
      assert(EndBB != nullptr && "EndBB should not be null");
      EndBB->getTerminator()->setSuccessor(0, CGEndBB);
    };

    auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
                      Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
      ReplacementValue = &Inner;
      return CodeGenIP;
    };

    auto FiniCB = [&](InsertPointTy CodeGenIP) {};
    /// Create a sequential execution region within a merged parallel region,
    /// encapsulated in a master construct with a barrier for synchronization.
    auto CreateSequentialRegion = [&](Function *OuterFn,
                                      BasicBlock *OuterPredBB,
                                      Instruction *SeqStartI,
                                      Instruction *SeqEndI) {
      // Isolate the instructions of the sequential region to a separate
      // block.
      BasicBlock *ParentBB = SeqStartI->getParent();
      BasicBlock *SeqEndBB =
          SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
      BasicBlock *SeqAfterBB =
          SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
      BasicBlock *SeqStartBB =
          SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");

      assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
             "Expected a different CFG");
      const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
      ParentBB->getTerminator()->eraseFromParent();

      auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
                           BasicBlock &ContinuationIP) {
        BasicBlock *CGStartBB = CodeGenIP.getBlock();
        BasicBlock *CGEndBB =
            SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
        assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
        CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
        assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
        SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
      };
      auto FiniCB = [&](InsertPointTy CodeGenIP) {};

      // Find outputs from the sequential region to outside users and
      // broadcast their values to them.
      for (Instruction &I : *SeqStartBB) {
        SmallPtrSet<Instruction *, 4> OutsideUsers;
        for (User *Usr : I.users()) {
          Instruction &UsrI = *cast<Instruction>(Usr);
          // Ignore outputs to LT intrinsics, code extraction for the merged
          // parallel region will fix them.
          if (UsrI.isLifetimeStartOrEnd())
            continue;

          if (UsrI.getParent() != SeqStartBB)
            OutsideUsers.insert(&UsrI);
        }

        if (OutsideUsers.empty())
          continue;

        // Emit an alloca in the outer region to store the broadcasted
        // value.
        const DataLayout &DL = M.getDataLayout();
        AllocaInst *AllocaI = new AllocaInst(
            I.getType(), DL.getAllocaAddrSpace(), nullptr,
            I.getName() + ".seq.output.alloc", &OuterFn->front().front());

        // Emit a store instruction in the sequential BB to update the
        // value.
        new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());

        // Emit a load instruction and replace the use of the output value
        // with it.
        for (Instruction *UsrI : OutsideUsers) {
          LoadInst *LoadI = new LoadInst(
              I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
          UsrI->replaceUsesOfWith(&I, LoadI);
        }
      }

      OpenMPIRBuilder::LocationDescription Loc(
          InsertPointTy(ParentBB, ParentBB->end()), DL);
      InsertPointTy SeqAfterIP =
          OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);

      OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);

      BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());

      LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
                        << "\n");
    };
    // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
    // contained in BB and only separated by instructions that can be
    // redundantly executed in parallel. The block BB is split before the first
    // call (in MergableCIs) and after the last so the entire region we merge
    // into a single parallel region is contained in a single basic block
    // without any other instructions. We use the OpenMPIRBuilder to outline
    // that block and call the resulting function via __kmpc_fork_call.
    auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) {
      // TODO: Change the interface to allow single CIs expanded, e.g, to
      // include an outer loop.
      assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");

      auto Remark = [&](OptimizationRemark OR) {
        OR << "Parallel region merged with parallel region"
           << (MergableCIs.size() > 2 ? "s" : "") << " at ";
        for (auto *CI : llvm::drop_begin(MergableCIs)) {
          OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
          if (CI != MergableCIs.back())
            OR << ", ";
        }
        return OR << ".";
      };

      emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);

      Function *OriginalFn = BB->getParent();
      LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
                        << " parallel regions in " << OriginalFn->getName()
                        << "\n");

      // Isolate the calls to merge in a separate block.
      EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
      BasicBlock *AfterBB =
          SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
      StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
                           "omp.par.merged");

      assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
      const DebugLoc DL = BB->getTerminator()->getDebugLoc();
      BB->getTerminator()->eraseFromParent();

      // Create sequential regions for sequential instructions that are
      // in-between mergable parallel regions.
      for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
           It != End; ++It) {
        Instruction *ForkCI = *It;
        Instruction *NextForkCI = *(It + 1);

        // Continue if there are no in-between instructions.
        if (ForkCI->getNextNode() == NextForkCI)
          continue;

        CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
                               NextForkCI->getPrevNode());
      }

      OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
                                               DL);
      IRBuilder<>::InsertPoint AllocaIP(
          &OriginalFn->getEntryBlock(),
          OriginalFn->getEntryBlock().getFirstInsertionPt());
      // Create the merged parallel region with default proc binding, to
      // avoid overriding binding settings, and without explicit cancellation.
      InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
          Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
          OMP_PROC_BIND_default, /* IsCancellable */ false);
      BranchInst::Create(AfterBB, AfterIP.getBlock());

      // Perform the actual outlining.
      OMPInfoCache.OMPBuilder.finalize(OriginalFn,
                                       /* AllowExtractorSinking */ true);

      Function *OutlinedFn = MergableCIs.front()->getCaller();

      // Replace the __kmpc_fork_call calls with direct calls to the outlined
      // callbacks.
      SmallVector<Value *, 8> Args;
      for (auto *CI : MergableCIs) {
        Value *Callee =
            CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
        FunctionType *FT =
            cast<FunctionType>(Callee->getType()->getPointerElementType());
        Args.clear();
        Args.push_back(OutlinedFn->getArg(0));
        Args.push_back(OutlinedFn->getArg(1));
        for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
             U < E; ++U)
          Args.push_back(CI->getArgOperand(U));

        CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
        if (CI->getDebugLoc())
          NewCI->setDebugLoc(CI->getDebugLoc());

        // Forward parameter attributes from the callback to the callee.
        for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
             U < E; ++U)
          for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
            NewCI->addParamAttr(
                U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);

        // Emit an explicit barrier to replace the implicit fork-join barrier.
        if (CI != MergableCIs.back()) {
          // TODO: Remove barrier if the merged parallel region includes the
          // 'nowait' clause.
          OMPInfoCache.OMPBuilder.createBarrier(
              InsertPointTy(NewCI->getParent(),
                            NewCI->getNextNode()->getIterator()),
              OMPD_parallel);
        }

        CI->eraseFromParent();
      }

      assert(OutlinedFn != OriginalFn && "Outlining failed");
      CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
      CGUpdater.reanalyzeFunction(*OriginalFn);

      NumOpenMPParallelRegionsMerged += MergableCIs.size();

      return true;
    };
    // Helper function that identifies sequences of
    // __kmpc_fork_call uses in a basic block.
    auto DetectPRsCB = [&](Use &U, Function &F) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      BB2PRMap[CI->getParent()].insert(CI);

      return false;
    };

    RFI.foreachUse(SCC, DetectPRsCB);
    SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
    // Find mergable parallel regions within a basic block that are
    // safe to merge, that is any in-between instructions can safely
    // execute in parallel after merging.
    // TODO: support merging across basic-blocks.
    for (auto &It : BB2PRMap) {
      auto &CIs = It.getSecond();
      if (CIs.size() < 2)
        continue;

      BasicBlock *BB = It.getFirst();
      SmallVector<CallInst *, 4> MergableCIs;

      /// Returns true if the instruction is mergable, false otherwise.
      /// A terminator instruction is unmergable by definition since merging
      /// works within a BB. Instructions before the mergable region are
      /// mergable if they are not calls to OpenMP runtime functions that may
      /// set different execution parameters for subsequent parallel regions.
      /// Instructions in-between parallel regions are mergable if they are not
      /// calls to any non-intrinsic function since that may call a non-mergable
      /// OpenMP runtime function.
      auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
        // We do not merge across BBs, hence return false (unmergable) if the
        // instruction is a terminator.
        if (I.isTerminator())
          return false;

        if (!isa<CallInst>(&I))
          return true;

        CallInst *CI = cast<CallInst>(&I);
        if (IsBeforeMergableRegion) {
          Function *CalledFunction = CI->getCalledFunction();
          if (!CalledFunction)
            return false;
          // Return false (unmergable) if the call before the parallel
          // region calls an explicit affinity (proc_bind) or number of
          // threads (num_threads) compiler-generated function. Those settings
          // may be incompatible with following parallel regions.
          // TODO: ICV tracking to detect compatibility.
          for (const auto &RFI : UnmergableCallsInfo) {
            if (CalledFunction == RFI.Declaration)
              return false;
          }
        } else {
          // Return false (unmergable) if there is a call instruction
          // in-between parallel regions when it is not an intrinsic. It
          // may call an unmergable OpenMP runtime function in its callpath.
          // TODO: Keep track of possible OpenMP calls in the callpath.
          if (!isa<IntrinsicInst>(CI))
            return false;
        }

        return true;
      };

      // Find maximal number of parallel region CIs that are safe to merge.
      for (auto It = BB->begin(), End = BB->end(); It != End;) {
        Instruction &I = *It;
        ++It;

        if (CIs.count(&I)) {
          MergableCIs.push_back(cast<CallInst>(&I));
          continue;
        }

        // Continue expanding if the instruction is mergable.
        if (IsMergable(I, MergableCIs.empty()))
          continue;

        // Forward the instruction iterator to skip the next parallel region
        // since there is an unmergable instruction which can affect it.
        for (; It != End; ++It) {
          Instruction &SkipI = *It;
          if (CIs.count(&SkipI)) {
            LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
                              << " due to " << I << "\n");
            ++It;
            break;
          }
        }

        // Store mergable regions found.
        if (MergableCIs.size() > 1) {
          MergableCIsVector.push_back(MergableCIs);
          LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
                            << " parallel regions in block " << BB->getName()
                            << " of function " << BB->getParent()->getName()
                            << "\n");
        }

        MergableCIs.clear();
      }

      if (!MergableCIsVector.empty()) {
        Changed = true;

        for (auto &MergableCIs : MergableCIsVector)
          Merge(MergableCIs, BB);
        MergableCIsVector.clear();
      }
    }

    if (Changed) {
      /// Re-collect uses for fork calls, emitted barrier calls, and
      /// any emitted master/end_master calls.
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
    }

    return Changed;
  }
  /// Try to delete parallel regions if possible.
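  ///
  /// Sketch (illustrative IR, argument details elided):
  ///
  ///   call void @__kmpc_fork_call(%struct.ident_t* @loc, i32 0,
  ///                               ... @outlined.readonly, ...)
  ///
  /// If the outlined callback only reads memory and is known to return (the
  /// onlyReadsMemory/WillReturn checks below), the fork call has no observable
  /// effect and is erased.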
  bool deleteParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;

    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    bool Changed = false;
    auto DeleteCallCB = [&](Use &U, Function &) {
      CallInst *CI = getCallIfRegularCall(U);
      if (!CI)
        return false;
      auto *Fn = dyn_cast<Function>(
          CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
      if (!Fn)
        return false;
      if (!Fn->onlyReadsMemory())
        return false;
      if (!Fn->hasFnAttribute(Attribute::WillReturn))
        return false;

      LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
                        << CI->getCaller()->getName() << "\n");

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Removing parallel region with no side-effects.";
      };
      emitRemark<OptimizationRemark>(CI, "OMP160", Remark);

      CGUpdater.removeCallSite(*CI);
      CI->eraseFromParent();
      Changed = true;
      ++NumOpenMPParallelRegionsDeleted;
      return true;
    };

    RFI.foreachUse(SCC, DeleteCallCB);

    return Changed;
  }
  /// Try to eliminate runtime calls by reusing existing ones.
  bool deduplicateRuntimeCalls() {
    bool Changed = false;

    RuntimeFunction DeduplicableRuntimeCallIDs[] = {
        OMPRTL_omp_get_num_threads,
        OMPRTL_omp_in_parallel,
        OMPRTL_omp_get_cancellation,
        OMPRTL_omp_get_thread_limit,
        OMPRTL_omp_get_supported_active_levels,
        OMPRTL_omp_get_level,
        OMPRTL_omp_get_ancestor_thread_num,
        OMPRTL_omp_get_team_size,
        OMPRTL_omp_get_active_level,
        OMPRTL_omp_in_final,
        OMPRTL_omp_get_proc_bind,
        OMPRTL_omp_get_num_places,
        OMPRTL_omp_get_num_procs,
        OMPRTL_omp_get_place_num,
        OMPRTL_omp_get_partition_num_places,
        OMPRTL_omp_get_partition_place_nums};

    // Global-tid is handled separately.
    SmallSetVector<Value *, 16> GTIdArgs;
    collectGlobalThreadIdArguments(GTIdArgs);
    LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
                      << " global thread ID arguments\n");

    for (Function *F : SCC) {
      for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
        Changed |= deduplicateRuntimeCalls(
            *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);

      // __kmpc_global_thread_num is special as we can replace it with an
      // argument in enough cases to make it worth trying.
      Value *GTIdArg = nullptr;
      for (Argument &Arg : F->args())
        if (GTIdArgs.count(&Arg)) {
          GTIdArg = &Arg;
          break;
        }
      Changed |= deduplicateRuntimeCalls(
          *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
    }

    return Changed;
  }
  /// Tries to hide the latency of runtime calls that involve host to
  /// device memory transfers by splitting them into their "issue" and "wait"
  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
  /// moved downwards as much as possible. The "issue" issues the memory
  /// transfer asynchronously, returning a handle. The "wait" waits in the
  /// returned handle for the memory transfer to finish.
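  ///
  /// Roughly (a sketch; the actual declarations come from OMPKinds.def):
  ///
  ///   call void @__tgt_target_data_begin_mapper(...)
  ///   ; code that does not touch the mapped memory
  ///   <first use of the mapped data>
  ///
  /// becomes
  ///
  ///   call void @__tgt_target_data_begin_mapper_issue(..., %handle)
  ///   ; code that does not touch the mapped memory runs while the
  ///   ; transfer is in flight
  ///   call void @__tgt_target_data_begin_mapper_wait(%device_id, %handle)
  ///   <first use of the mapped data>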
  bool hideMemTransfersLatency() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
    bool Changed = false;
    auto SplitMemTransfers = [&](Use &U, Function &Decl) {
      auto *RTCall = getCallIfRegularCall(U, &RFI);
      if (!RTCall)
        return false;

      OffloadArray OffloadArrays[3];
      if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
        return false;

      LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));

      // TODO: Check if can be moved upwards.
      bool WasSplit = false;
      Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
      if (WaitMovementPoint)
        WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);

      Changed |= WasSplit;
      return WasSplit;
    };

    RFI.foreachUse(SCC, SplitMemTransfers);

    return Changed;
  }
  void analysisGlobalization() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];

    auto CheckGlobalization = [&](Use &U, Function &Decl) {
      if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
        auto Remark = [&](OptimizationRemarkMissed ORM) {
          return ORM
                 << "Found thread data sharing on the GPU. "
                 << "Expect degraded performance due to data globalization.";
        };
        emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
      }

      return false;
    };

    RFI.foreachUse(SCC, CheckGlobalization);
  }
  /// Maps the values stored in the offload arrays passed as arguments to
  /// \p RuntimeCall into the offload arrays in \p OAs.
  bool getValuesInOffloadArrays(CallInst &RuntimeCall,
                                MutableArrayRef<OffloadArray> OAs) {
    assert(OAs.size() == 3 && "Need space for three offload arrays!");

    // A runtime call that involves memory offloading looks something like:
    // call void @__tgt_target_data_begin_mapper(arg0, arg1,
    //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
    //   ...)
    // So, the idea is to access the allocas that allocate space for these
    // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.

    // i8** %offload_baseptrs.
    Value *BasePtrsArg =
        RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
    // i8** %offload_ptrs.
    Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
    // i8** %offload_sizes.
    Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);

    // Get values stored in **offload_baseptrs.
    auto *V = getUnderlyingObject(BasePtrsArg);
    if (!isa<AllocaInst>(V))
      return false;
    auto *BasePtrsArray = cast<AllocaInst>(V);
    if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
      return false;

    // Get values stored in **offload_ptrs.
    V = getUnderlyingObject(PtrsArg);
    if (!isa<AllocaInst>(V))
      return false;
    auto *PtrsArray = cast<AllocaInst>(V);
    if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
      return false;

    // Get values stored in **offload_sizes.
    V = getUnderlyingObject(SizesArg);
    // If it's a [constant] global array don't analyze it.
    if (isa<GlobalValue>(V))
      return isa<Constant>(V);
    if (!isa<AllocaInst>(V))
      return false;

    auto *SizesArray = cast<AllocaInst>(V);
    if (!OAs[2].initialize(*SizesArray, RuntimeCall))
      return false;

    return true;
  }
  /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
  /// For now this is a way to test that the function getValuesInOffloadArrays
  /// is working properly.
  /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
  void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
    assert(OAs.size() == 3 && "There are three offload arrays to debug!");

    LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
    std::string ValuesStr;
    raw_string_ostream Printer(ValuesStr);
    std::string Separator = " --- ";

    for (auto *BP : OAs[0].StoredValues) {
      BP->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
    ValuesStr.clear();

    for (auto *P : OAs[1].StoredValues) {
      P->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
    ValuesStr.clear();

    for (auto *S : OAs[2].StoredValues) {
      S->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
  }
  /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
  /// moved. Returns nullptr if the movement is not possible, or not worth it.
  Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
    // FIXME: This traverses only the BasicBlock where RuntimeCall is.
    //  Make it traverse the CFG.

    Instruction *CurrentI = &RuntimeCall;
    bool IsWorthIt = false;
    while ((CurrentI = CurrentI->getNextNode())) {

      // TODO: Once we detect the regions to be offloaded we should use the
      //  alias analysis manager to check if CurrentI may modify one of
      //  the offloaded regions.
      if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
        if (IsWorthIt)
          return CurrentI;

        return nullptr;
      }

      // FIXME: For now moving it over anything without side effects
      //  is considered worth it.
      IsWorthIt = true;
    }

    // Return end of BasicBlock.
    return RuntimeCall.getParent()->getTerminator();
  }
  /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
  bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
                               Instruction &WaitMovementPoint) {
    // Create a stack allocated handle (__tgt_async_info) at the beginning of
    // the function. Used for storing information of the async transfer,
    // allowing to wait on it later.
    auto &IRBuilder = OMPInfoCache.OMPBuilder;
    auto *F = RuntimeCall.getCaller();
    Instruction *FirstInst = &(F->getEntryBlock().front());
    AllocaInst *Handle = new AllocaInst(
        IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);

    // Add "issue" runtime call declaration:
    // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
    //   i8**, i8**, i64*, i64*)
    FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
        M, OMPRTL___tgt_target_data_begin_mapper_issue);

    // Change RuntimeCall call site for its asynchronous version.
    SmallVector<Value *, 16> Args;
    for (auto &Arg : RuntimeCall.args())
      Args.push_back(Arg.get());
    Args.push_back(Handle);

    CallInst *IssueCallsite =
        CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
    RuntimeCall.eraseFromParent();

    // Add "wait" runtime call declaration:
    // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
    FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
        M, OMPRTL___tgt_target_data_begin_mapper_wait);

    Value *WaitParams[2] = {
        IssueCallsite->getArgOperand(
            OffloadArray::DeviceIDArgNum), // device_id.
        Handle                             // handle to wait on.
    };
    CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);

    return true;
  }
*combinedIdentStruct(Value
*CurrentIdent
, Value
*NextIdent
,
1525 bool GlobalOnly
, bool &SingleChoice
) {
1526 if (CurrentIdent
== NextIdent
)
1527 return CurrentIdent
;
1529 // TODO: Figure out how to actually combine multiple debug locations. For
1530 // now we just keep an existing one if there is a single choice.
1531 if (!GlobalOnly
|| isa
<GlobalValue
>(NextIdent
)) {
1532 SingleChoice
= !CurrentIdent
;
1538 /// Return an `struct ident_t*` value that represents the ones used in the
1539 /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1540 /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1541 /// return value we create one from scratch. We also do not yet combine
1542 /// information, e.g., the source locations, see combinedIdentStruct.
1544 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo
&RFI
,
1545 Function
&F
, bool GlobalOnly
) {
1546 bool SingleChoice
= true;
1547 Value
*Ident
= nullptr;
1548 auto CombineIdentStruct
= [&](Use
&U
, Function
&Caller
) {
1549 CallInst
*CI
= getCallIfRegularCall(U
, &RFI
);
1550 if (!CI
|| &F
!= &Caller
)
1552 Ident
= combinedIdentStruct(Ident
, CI
->getArgOperand(0),
1553 /* GlobalOnly */ true, SingleChoice
);
1556 RFI
.foreachUse(SCC
, CombineIdentStruct
);
1558 if (!Ident
|| !SingleChoice
) {
1559 // The IRBuilder uses the insertion block to get to the module, this is
1560 // unfortunate but we work around it for now.
1561 if (!OMPInfoCache
.OMPBuilder
.getInsertionPoint().getBlock())
1562 OMPInfoCache
.OMPBuilder
.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1563 &F
.getEntryBlock(), F
.getEntryBlock().begin()));
1564 // Create a fallback location if non was found.
1565 // TODO: Use the debug locations of the calls instead.
1566 Constant
*Loc
= OMPInfoCache
.OMPBuilder
.getOrCreateDefaultSrcLocStr();
1567 Ident
= OMPInfoCache
.OMPBuilder
.getOrCreateIdent(Loc
);
  /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
  /// \p ReplVal if given.
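  ///
  /// E.g. (sketch), for __kmpc_global_thread_num:
  ///
  ///   %tid0 = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
  ///   ...
  ///   %tid1 = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
  ///
  /// the second call is deleted and its uses are rewired to %tid0, or to
  /// \p ReplVal (e.g. a thread-id argument) if one is provided.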
  bool deduplicateRuntimeCalls(Function &F,
                               OMPInformationCache::RuntimeFunctionInfo &RFI,
                               Value *ReplVal = nullptr) {
    auto *UV = RFI.getUseVector(F);
    if (!UV || UV->size() + (ReplVal != nullptr) < 2)
      return false;

    LLVM_DEBUG(
        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
               << (ReplVal ? " with an existing value\n" : "\n") << "\n");

    assert((!ReplVal || (isa<Argument>(ReplVal) &&
                         cast<Argument>(ReplVal)->getParent() == &F)) &&
           "Unexpected replacement value!");

    // TODO: Use dominance to find a good position instead.
    auto CanBeMoved = [this](CallBase &CB) {
      unsigned NumArgs = CB.getNumArgOperands();
      if (NumArgs == 0)
        return true;
      if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
        return false;
      for (unsigned u = 1; u < NumArgs; ++u)
        if (isa<Instruction>(CB.getArgOperand(u)))
          return false;
      return true;
    };

    if (!ReplVal) {
      for (Use *U : *UV)
        if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
          if (!CanBeMoved(*CI))
            continue;

          // If the function is a kernel, dedup will move
          // the runtime call right after the kernel init callsite. Otherwise,
          // it will move it to the beginning of the caller function.
          if (isKernel(F)) {
            auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
            auto *KernelInitUV = KernelInitRFI.getUseVector(F);

            if (KernelInitUV->empty())
              continue;

            assert(KernelInitUV->size() == 1 &&
                   "Expected a single __kmpc_target_init in kernel\n");

            CallInst *KernelInitCI =
                getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
            assert(KernelInitCI &&
                   "Expected a call to __kmpc_target_init in kernel\n");

            CI->moveAfter(KernelInitCI);
          } else
            CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());

          ReplVal = CI;
          break;
        }
      if (!ReplVal)
        return false;
    }

    // If we use a call as a replacement value we need to make sure the ident is
    // valid at the new location. For now we just pick a global one, either
    // existing and used by one of the calls, or created from scratch.
    if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
      if (CI->getNumArgOperands() > 0 &&
          CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
        Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
                                                      /* GlobalOnly */ true);
        CI->setArgOperand(0, Ident);
      }
    }

    bool Changed = false;
    auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      if (!CI || CI == ReplVal || &F != &Caller)
        return false;
      assert(CI->getCaller() == &F && "Unexpected call!");

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "OpenMP runtime call "
                  << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
      };
      if (CI->getDebugLoc())
        emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
      else
        emitRemark<OptimizationRemark>(&F, "OMP170", Remark);

      CGUpdater.removeCallSite(*CI);
      CI->replaceAllUsesWith(ReplVal);
      CI->eraseFromParent();
      ++NumOpenMPRuntimeCallsDeduplicated;
      Changed = true;
      return true;
    };
    RFI.foreachUse(SCC, ReplaceAndDeleteCB);

    return Changed;
  }
  /// Collect arguments that represent the global thread id in \p GTIdArgs.
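  ///
  /// Sketch: in
  ///
  ///   %tid = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
  ///   call void @helper(i32 %tid, ...)
  ///
  /// the first parameter of the (internal) function @helper is recorded as a
  /// global thread ID argument if every call site passes such a value.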
  void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
    // TODO: Below we basically perform a fixpoint iteration with a pessimistic
    //       initialization. We could define an AbstractAttribute instead and
    //       run the Attributor here once it can be run as an SCC pass.

    // Helper to check the argument \p ArgNo at all call sites of \p F for
    // a GTId.
    auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
      if (!F.hasLocalLinkage())
        return false;
      for (Use &U : F.uses()) {
        if (CallInst *CI = getCallIfRegularCall(U)) {
          Value *ArgOp = CI->getArgOperand(ArgNo);
          if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
              getCallIfRegularCall(
                  *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
            continue;
        }
        return false;
      }
      return true;
    };

    // Helper to identify uses of a GTId as GTId arguments.
    auto AddUserArgs = [&](Value &GTId) {
      for (Use &U : GTId.uses())
        if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
          if (CI->isArgOperand(&U))
            if (Function *Callee = CI->getCalledFunction())
              if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
                GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
    };

    // The argument users of __kmpc_global_thread_num calls are GTIds.
    OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];

    GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
      if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
        AddUserArgs(*CI);
      return false;
    });

    // Transitively search for more arguments by looking at the users of the
    // ones we know already. During the search the GTIdArgs vector is extended
    // so we cannot cache the size nor can we use a range based for.
    for (unsigned u = 0; u < GTIdArgs.size(); ++u)
      AddUserArgs(*GTIdArgs[u]);
  }
  /// Kernel (=GPU) optimizations and utility functions

  /// Check if \p F is a kernel, hence an entry point for target offloading.
  bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }

  /// Cache to remember the unique kernel for a function.
  DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;

  /// Find the unique kernel that will execute \p F, if any.
  Kernel getUniqueKernelFor(Function &F);

  /// Find the unique kernel that will execute \p I, if any.
  Kernel getUniqueKernelFor(Instruction &I) {
    return getUniqueKernelFor(*I.getFunction());
  }

  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
  /// the cases we can avoid taking the address of a function.
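  ///
  /// Sketch of the idea: in generic mode the worker threads run a state
  /// machine that is handed a pointer to the next parallel region function.
  /// If all parallel regions reaching a kernel are known, the indirect
  /// dispatch can compare against a unique global ID per region and call the
  /// outlined function directly, so its address never needs to be taken.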
  bool rewriteDeviceCodeStateMachine();

  /// Emit a remark generically
  ///
  /// This template function can be used to generically emit a remark. The
  /// RemarkKind should be one of the following:
  ///   - OptimizationRemark to indicate a successful optimization attempt
  ///   - OptimizationRemarkMissed to report a failed optimization attempt
  ///   - OptimizationRemarkAnalysis to provide additional information about an
  ///     optimization attempt
  ///
  /// The remark is built using a callback function provided by the caller that
  /// takes a RemarkKind as input and returns a RemarkKind.
  template <typename RemarkKind, typename RemarkCallBack>
  void emitRemark(Instruction *I, StringRef RemarkName,
                  RemarkCallBack &&RemarkCB) const {
    Function *F = I->getParent()->getParent();
    auto &ORE = OREGetter(F);

    if (RemarkName.startswith("OMP"))
      ORE.emit([&]() {
        return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
               << " [" << RemarkName << "]";
      });
    else
      ORE.emit(
          [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
  }
1779 /// Emit a remark on a function.
1780 template <typename RemarkKind
, typename RemarkCallBack
>
1781 void emitRemark(Function
*F
, StringRef RemarkName
,
1782 RemarkCallBack
&&RemarkCB
) const {
1783 auto &ORE
= OREGetter(F
);
1785 if (RemarkName
.startswith("OMP"))
1787 return RemarkCB(RemarkKind(DEBUG_TYPE
, RemarkName
, F
))
1788 << " [" << RemarkName
<< "]";
1792 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE
, RemarkName
, F
)); });
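  // Illustrative usage sketch only (callers elsewhere in this file follow the
  // same pattern; "OMP199" is a made-up remark identifier):
  //
  //   auto Remark = [&](OptimizationRemark OR) {
  //     return OR << "Some transformation happened.";
  //   };
  //   emitRemark<OptimizationRemark>(&F, "OMP199", Remark);
  //
  // The real identifiers used in this file are documented at
  // https://openmp.llvm.org/remarks/OptimizationRemarks.html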
  /// RAII struct to temporarily change an RTL function's linkage to external.
  /// This prevents it from being mistakenly removed by other optimizations.
  struct ExternalizationRAII {
    ExternalizationRAII(OMPInformationCache &OMPInfoCache,
                        RuntimeFunction RFKind)
        : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) {
      if (!Declaration)
        return;

      LinkageType = Declaration->getLinkage();
      Declaration->setLinkage(GlobalValue::ExternalLinkage);
    }

    ~ExternalizationRAII() {
      if (!Declaration)
        return;

      // Set the linkage back to the original type.
      Declaration->setLinkage(LinkageType);
    }

    Function *Declaration;
    GlobalValue::LinkageTypes LinkageType;
  };
  /// The underlying module.
  Module &M;

  /// The SCC we are operating on.
  SmallVectorImpl<Function *> &SCC;

  /// Callback to update the call graph, the first argument is a removed call,
  /// the second an optional replacement call.
  CallGraphUpdater &CGUpdater;

  /// Callback to get an OptimizationRemarkEmitter from a Function *.
  OptimizationRemarkGetter OREGetter;

  /// OpenMP-specific information cache. Also used for Attributor runs.
  OMPInformationCache &OMPInfoCache;

  /// Attributor instance.
  Attributor &A;

  /// Helper function to run Attributor on SCC.
  bool runAttributor(bool IsModulePass) {
    if (SCC.empty())
      return false;

    // Temporarily make these functions have external linkage so the Attributor
    // doesn't remove them when we try to look them up later.
    ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);
    ExternalizationRAII EndParallel(OMPInfoCache,
                                    OMPRTL___kmpc_kernel_end_parallel);
    ExternalizationRAII BarrierSPMD(OMPInfoCache,
                                    OMPRTL___kmpc_barrier_simple_spmd);

    registerAAs(IsModulePass);

    ChangeStatus Changed = A.run();

    LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
                      << " functions, result: " << Changed << ".\n");

    return Changed == ChangeStatus::CHANGED;
  }

  void registerFoldRuntimeCall(RuntimeFunction RF);

  /// Populate the Attributor with abstract attribute opportunities in the
  /// function.
  void registerAAs(bool IsModulePass);
};
Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
  if (!OMPInfoCache.ModuleSlice.count(&F))
    return nullptr;

  // Use a scope to keep the lifetime of the CachedKernel short.
  {
    Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
    if (CachedKernel)
      return *CachedKernel;

    // TODO: We should use an AA to create an (optimistic and callback
    //       call-aware) call graph. For now we stick to simple patterns that
    //       are less powerful, basically the worst fixpoint.
    if (isKernel(F)) {
      CachedKernel = Kernel(&F);
      return *CachedKernel;
    }

    CachedKernel = nullptr;
    if (!F.hasLocalLinkage()) {

      // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
        return ORA << "Potentially unknown OpenMP target region caller.";
      };
      emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);

      return nullptr;
    }
  }

  auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
    if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
      // Allow use in equality comparisons.
      if (Cmp->isEquality())
        return getUniqueKernelFor(*Cmp);
      return nullptr;
    }
    if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
      // Allow direct calls.
      if (CB->isCallee(&U))
        return getUniqueKernelFor(*CB);

      OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
          OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
      // Allow the use in __kmpc_parallel_51 calls.
      if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
        return getUniqueKernelFor(*CB);
      return nullptr;
    }
    // Disallow every other use.
    return nullptr;
  };

  // TODO: In the future we want to track more than just a unique kernel.
  SmallPtrSet<Kernel, 2> PotentialKernels;
  OMPInformationCache::foreachUse(F, [&](const Use &U) {
    PotentialKernels.insert(GetUniqueKernelForUse(U));
  });

  Kernel K = nullptr;
  if (PotentialKernels.size() == 1)
    K = *PotentialKernels.begin();

  // Cache the result.
  UniqueKernelMap[&F] = K;

  return K;
}
bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
  OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
      OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];

  bool Changed = false;
  if (!KernelParallelRFI)
    return Changed;

  // If we have disabled state machine changes, exit.
  if (DisableOpenMPOptStateMachineRewrite)
    return Changed;

  for (Function *F : SCC) {

    // Check if the function is a use in a __kmpc_parallel_51 call at
    // all.
    bool UnknownUse = false;
    bool KernelParallelUse = false;
    unsigned NumDirectCalls = 0;

    SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
    OMPInformationCache::foreachUse(*F, [&](Use &U) {
      if (auto *CB = dyn_cast<CallBase>(U.getUser()))
        if (CB->isCallee(&U)) {
          ++NumDirectCalls;
          return;
        }

      if (isa<ICmpInst>(U.getUser())) {
        ToBeReplacedStateMachineUses.push_back(&U);
        return;
      }

      // Find wrapper functions that represent parallel kernels.
      CallInst *CI =
          OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
      const unsigned int WrapperFunctionArgNo = 6;
      if (!KernelParallelUse && CI &&
          CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
        KernelParallelUse = true;
        ToBeReplacedStateMachineUses.push_back(&U);
        return;
      }

      UnknownUse = true;
    });

    // Do not emit a remark if we haven't seen a __kmpc_parallel_51
    // use.
    if (!KernelParallelUse)
      continue;

    // If this ever hits, we should investigate.
    // TODO: Checking the number of uses is not a necessary restriction and
    //       should be lifted.
    if (UnknownUse || NumDirectCalls != 1 ||
        ToBeReplacedStateMachineUses.size() > 2) {
      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
        return ORA << "Parallel region is used in "
                   << (UnknownUse ? "unknown" : "unexpected")
                   << " ways. Will not attempt to rewrite the state machine.";
      };
      emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
      continue;
    }

    // Even if we have __kmpc_parallel_51 calls, we (for now) give
    // up if the function is not called from a unique kernel.
    Kernel K = getUniqueKernelFor(*F);
    if (!K) {
      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
        return ORA << "Parallel region is not called from a unique kernel. "
                      "Will not attempt to rewrite the state machine.";
      };
      emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
      continue;
    }

    // We now know F is a parallel body function called only from the kernel K.
    // We also identified the state machine uses in which we replace the
    // function pointer by a new global symbol for identification purposes.
    // This ensures only direct calls to the function are left.

    Module &M = *F->getParent();
    Type *Int8Ty = Type::getInt8Ty(M.getContext());

    auto *ID = new GlobalVariable(
        M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
        UndefValue::get(Int8Ty), F->getName() + ".ID");

    for (Use *U : ToBeReplacedStateMachineUses)
      U->set(ConstantExpr::getBitCast(ID, U->get()->getType()));

    ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
    Changed = true;
  }

  return Changed;
}
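// Illustrative sketch only: after the rewrite above, a comparison in the
// generic state machine that previously checked the parallel wrapper function
// itself, e.g.
//
//   WorkFn == @__omp_outlined_wrapper        ; address of @__omp_outlined_wrapper taken
//
// conceptually becomes a comparison against the new private global
// @__omp_outlined_wrapper.ID, so only direct calls to the wrapper remain.
// @__omp_outlined_wrapper is a hypothetical name used for illustration.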
/// Abstract Attribute for tracking ICV values.
struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
  using Base = StateWrapper<BooleanState, AbstractAttribute>;
  AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}

  void initialize(Attributor &A) override {
    Function *F = getAnchorScope();
    if (!F || !A.isFunctionIPOAmendable(*F))
      indicatePessimisticFixpoint();
  }

  /// Returns true if value is assumed to be tracked.
  bool isAssumedTracked() const { return getAssumed(); }

  /// Returns true if value is known to be tracked.
  bool isKnownTracked() const { return getAssumed(); }

  /// Create an abstract attribute view for the position \p IRP.
  static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);

  /// Return the value with which \p I can be replaced for specific \p ICV.
  virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
                                                const Instruction *I,
                                                Attributor &A) const {
    return None;
  }

  /// Return an assumed unique ICV value if a single candidate is found. If
  /// there cannot be one, return a nullptr. If it is not clear yet, return the
  /// Optional::NoneType.
  virtual Optional<Value *>
  getUniqueReplacementValue(InternalControlVar ICV) const = 0;

  // Currently only nthreads is being tracked.
  // This array will only grow with time.
  InternalControlVar TrackableICVs[1] = {ICV_nthreads};

  /// See AbstractAttribute::getName()
  const std::string getName() const override { return "AAICVTracker"; }

  /// See AbstractAttribute::getIdAddr()
  const char *getIdAddr() const override { return &ID; }

  /// This function should return true if the type of the \p AA is AAICVTracker
  static bool classof(const AbstractAttribute *AA) {
    return (AA->getIdAddr() == &ID);
  }

  /// Unique ID (due to the unique address)
  static const char ID;
};
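// Illustrative sketch only: with nthreads as the tracked ICV, the intent is
// that a source sequence like
//
//   omp_set_num_threads(4);        // setter, records the value 4
//   int N = omp_get_max_threads(); // getter, can be folded to 4
//
// allows the getter call to be replaced, provided no intervening call may
// have changed the ICV.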
struct AAICVTrackerFunction : public AAICVTracker {
  AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
      : AAICVTracker(IRP, A) {}

  // FIXME: come up with a better string.
  const std::string getAsStr() const override { return "ICVTrackerFunction"; }

  // FIXME: come up with some stats.
  void trackStatistics() const override {}

  /// We don't manifest anything for this AA.
  ChangeStatus manifest(Attributor &A) override {
    return ChangeStatus::UNCHANGED;
  }

  // Map of ICV to their values at specific program point.
  EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
                  InternalControlVar::ICV___last>
      ICVReplacementValuesMap;

  ChangeStatus updateImpl(Attributor &A) override {
    ChangeStatus HasChanged = ChangeStatus::UNCHANGED;

    Function *F = getAnchorScope();

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

    for (InternalControlVar ICV : TrackableICVs) {
      auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];

      auto &ValuesMap = ICVReplacementValuesMap[ICV];
      auto TrackValues = [&](Use &U, Function &) {
        CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
        if (!CI)
          return false;

        // FIXME: handle setters with more than 1 argument.
        /// Track new value.
        if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
          HasChanged = ChangeStatus::CHANGED;

        return false;
      };

      auto CallCheck = [&](Instruction &I) {
        Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
        if (ReplVal.hasValue() &&
            ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
          HasChanged = ChangeStatus::CHANGED;

        return true;
      };

      // Track all changes of an ICV.
      SetterRFI.foreachUse(TrackValues, F);

      bool UsedAssumedInformation = false;
      A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
                                UsedAssumedInformation,
                                /* CheckBBLivenessOnly */ true);

      /// TODO: Figure out a way to avoid adding entry in
      /// ICVReplacementValuesMap
      Instruction *Entry = &F->getEntryBlock().front();
      if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
        ValuesMap.insert(std::make_pair(Entry, nullptr));
    }

    return HasChanged;
  }
  /// Helper to check if \p I is a call and get the value for it if it is
  /// unique.
  Optional<Value *> getValueForCall(Attributor &A, const Instruction *I,
                                    InternalControlVar &ICV) const {

    const auto *CB = dyn_cast<CallBase>(I);
    if (!CB || CB->hasFnAttr("no_openmp") ||
        CB->hasFnAttr("no_openmp_routines"))
      return None;

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
    auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
    Function *CalledFunction = CB->getCalledFunction();

    // Indirect call, assume ICV changes.
    if (CalledFunction == nullptr)
      return nullptr;
    if (CalledFunction == GetterRFI.Declaration)
      return None;
    if (CalledFunction == SetterRFI.Declaration) {
      if (ICVReplacementValuesMap[ICV].count(I))
        return ICVReplacementValuesMap[ICV].lookup(I);

      return nullptr;
    }

    // Since we don't know, assume it changes the ICV.
    if (CalledFunction->isDeclaration())
      return nullptr;

    const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
        *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);

    if (ICVTrackingAA.isAssumedTracked())
      return ICVTrackingAA.getUniqueReplacementValue(ICV);

    // If we don't know, assume it changes.
    return nullptr;
  }
  // We don't check unique value for a function, so return None.
  Optional<Value *>
  getUniqueReplacementValue(InternalControlVar ICV) const override {
    return None;
  }

  /// Return the value with which \p I can be replaced for specific \p ICV.
  Optional<Value *> getReplacementValue(InternalControlVar ICV,
                                        const Instruction *I,
                                        Attributor &A) const override {
    const auto &ValuesMap = ICVReplacementValuesMap[ICV];
    if (ValuesMap.count(I))
      return ValuesMap.lookup(I);

    SmallVector<const Instruction *, 16> Worklist;
    SmallPtrSet<const Instruction *, 16> Visited;
    Worklist.push_back(I);

    Optional<Value *> ReplVal;

    while (!Worklist.empty()) {
      const Instruction *CurrInst = Worklist.pop_back_val();
      if (!Visited.insert(CurrInst).second)
        continue;

      const BasicBlock *CurrBB = CurrInst->getParent();

      // Go up and look for all potential setters/calls that might change the
      // ICV.
      while ((CurrInst = CurrInst->getPrevNode())) {
        if (ValuesMap.count(CurrInst)) {
          Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
          // Unknown value, track new.
          if (!ReplVal.hasValue()) {
            ReplVal = NewReplVal;
            break;
          }

          // If we found a new value, we can't know the icv value anymore.
          if (NewReplVal.hasValue())
            if (ReplVal != NewReplVal)
              return nullptr;

          break;
        }

        Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV);
        if (!NewReplVal.hasValue())
          continue;

        // Unknown value, track new.
        if (!ReplVal.hasValue()) {
          ReplVal = NewReplVal;
          break;
        }

        // if (NewReplVal.hasValue())
        // We found a new value, we can't know the icv value anymore.
        if (ReplVal != NewReplVal)
          return nullptr;
      }

      // If we are in the same BB and we have a value, we are done.
      if (CurrBB == I->getParent() && ReplVal.hasValue())
        break;

      // Go through all predecessors and add terminators for analysis.
      for (const BasicBlock *Pred : predecessors(CurrBB))
        if (const Instruction *Terminator = Pred->getTerminator())
          Worklist.push_back(Terminator);
    }

    return ReplVal;
  }
};
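// Illustrative sketch only (block and value names are hypothetical): querying
// the nthreads ICV at the end of %merge below yields "unknown" (nullptr),
// because the backward walk in getReplacementValue above reaches two
// predecessors that recorded different values:
//
//   then:  call void @omp_set_num_threads(i32 4)   ; br label %merge
//   else:  call void @omp_set_num_threads(i32 8)   ; br label %merge
//   merge: %n = call i32 @omp_get_max_threads()    ; not foldable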
struct AAICVTrackerFunctionReturned : AAICVTracker {
  AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
      : AAICVTracker(IRP, A) {}

  // FIXME: come up with a better string.
  const std::string getAsStr() const override {
    return "ICVTrackerFunctionReturned";
  }

  // FIXME: come up with some stats.
  void trackStatistics() const override {}

  /// We don't manifest anything for this AA.
  ChangeStatus manifest(Attributor &A) override {
    return ChangeStatus::UNCHANGED;
  }

  // Map of ICV to their values at specific program point.
  EnumeratedArray<Optional<Value *>, InternalControlVar,
                  InternalControlVar::ICV___last>
      ICVReplacementValuesMap;

  /// Return the value with which \p I can be replaced for specific \p ICV.
  Optional<Value *>
  getUniqueReplacementValue(InternalControlVar ICV) const override {
    return ICVReplacementValuesMap[ICV];
  }

  ChangeStatus updateImpl(Attributor &A) override {
    ChangeStatus Changed = ChangeStatus::UNCHANGED;
    const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);

    if (!ICVTrackingAA.isAssumedTracked())
      return indicatePessimisticFixpoint();

    for (InternalControlVar ICV : TrackableICVs) {
      Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
      Optional<Value *> UniqueICVValue;

      auto CheckReturnInst = [&](Instruction &I) {
        Optional<Value *> NewReplVal =
            ICVTrackingAA.getReplacementValue(ICV, &I, A);

        // If we found a second ICV value there is no unique returned value.
        if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal)
          return false;

        UniqueICVValue = NewReplVal;

        return true;
      };

      bool UsedAssumedInformation = false;
      if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
                                     UsedAssumedInformation,
                                     /* CheckBBLivenessOnly */ true))
        UniqueICVValue = nullptr;

      if (UniqueICVValue == ReplVal)
        continue;

      ReplVal = UniqueICVValue;
      Changed = ChangeStatus::CHANGED;
    }

    return Changed;
  }
};
struct AAICVTrackerCallSite : AAICVTracker {
  AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
      : AAICVTracker(IRP, A) {}

  void initialize(Attributor &A) override {
    Function *F = getAnchorScope();
    if (!F || !A.isFunctionIPOAmendable(*F))
      indicatePessimisticFixpoint();

    // We only initialize this AA for getters, so we need to know which ICV it
    // gets.
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    for (InternalControlVar ICV : TrackableICVs) {
      auto ICVInfo = OMPInfoCache.ICVs[ICV];
      auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
      if (Getter.Declaration == getAssociatedFunction()) {
        AssociatedICV = ICVInfo.Kind;
        return;
      }
    }

    /// Unknown ICV.
    indicatePessimisticFixpoint();
  }

  ChangeStatus manifest(Attributor &A) override {
    if (!ReplVal.hasValue() || !ReplVal.getValue())
      return ChangeStatus::UNCHANGED;

    A.changeValueAfterManifest(*getCtxI(), **ReplVal);
    A.deleteAfterManifest(*getCtxI());

    return ChangeStatus::CHANGED;
  }

  // FIXME: come up with a better string.
  const std::string getAsStr() const override { return "ICVTrackerCallSite"; }

  // FIXME: come up with some stats.
  void trackStatistics() const override {}

  InternalControlVar AssociatedICV;
  Optional<Value *> ReplVal;

  ChangeStatus updateImpl(Attributor &A) override {
    const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);

    // We don't have any information, so we assume it changes the ICV.
    if (!ICVTrackingAA.isAssumedTracked())
      return indicatePessimisticFixpoint();

    Optional<Value *> NewReplVal =
        ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);

    if (ReplVal == NewReplVal)
      return ChangeStatus::UNCHANGED;

    ReplVal = NewReplVal;
    return ChangeStatus::CHANGED;
  }

  // Return the value with which associated value can be replaced for specific
  // \p ICV.
  Optional<Value *>
  getUniqueReplacementValue(InternalControlVar ICV) const override {
    return ReplVal;
  }
};
struct AAICVTrackerCallSiteReturned : AAICVTracker {
  AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
      : AAICVTracker(IRP, A) {}

  // FIXME: come up with a better string.
  const std::string getAsStr() const override {
    return "ICVTrackerCallSiteReturned";
  }

  // FIXME: come up with some stats.
  void trackStatistics() const override {}

  /// We don't manifest anything for this AA.
  ChangeStatus manifest(Attributor &A) override {
    return ChangeStatus::UNCHANGED;
  }

  // Map of ICV to their values at specific program point.
  EnumeratedArray<Optional<Value *>, InternalControlVar,
                  InternalControlVar::ICV___last>
      ICVReplacementValuesMap;

  /// Return the value with which associated value can be replaced for specific
  /// \p ICV.
  Optional<Value *>
  getUniqueReplacementValue(InternalControlVar ICV) const override {
    return ICVReplacementValuesMap[ICV];
  }

  ChangeStatus updateImpl(Attributor &A) override {
    ChangeStatus Changed = ChangeStatus::UNCHANGED;
    const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
        *this, IRPosition::returned(*getAssociatedFunction()),
        DepClassTy::REQUIRED);

    // We don't have any information, so we assume it changes the ICV.
    if (!ICVTrackingAA.isAssumedTracked())
      return indicatePessimisticFixpoint();

    for (InternalControlVar ICV : TrackableICVs) {
      Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
      Optional<Value *> NewReplVal =
          ICVTrackingAA.getUniqueReplacementValue(ICV);

      if (ReplVal == NewReplVal)
        continue;

      ReplVal = NewReplVal;
      Changed = ChangeStatus::CHANGED;
    }

    return Changed;
  }
};
struct AAExecutionDomainFunction : public AAExecutionDomain {
  AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
      : AAExecutionDomain(IRP, A) {}

  const std::string getAsStr() const override {
    return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
           "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
  }

  /// See AbstractAttribute::trackStatistics().
  void trackStatistics() const override {}

  void initialize(Attributor &A) override {
    Function *F = getAnchorScope();
    for (const auto &BB : *F)
      SingleThreadedBBs.insert(&BB);
    NumBBs = SingleThreadedBBs.size();
  }

  ChangeStatus manifest(Attributor &A) override {
    LLVM_DEBUG({
      for (const BasicBlock *BB : SingleThreadedBBs)
        dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
               << BB->getName() << " is executed by a single thread.\n";
    });

    return ChangeStatus::UNCHANGED;
  }

  ChangeStatus updateImpl(Attributor &A) override;

  /// Check if an instruction is executed by a single thread.
  bool isExecutedByInitialThreadOnly(const Instruction &I) const override {
    return isExecutedByInitialThreadOnly(*I.getParent());
  }

  bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
    return isValidState() && SingleThreadedBBs.contains(&BB);
  }

  /// Set of basic blocks that are executed by a single thread.
  DenseSet<const BasicBlock *> SingleThreadedBBs;

  /// Total number of basic blocks in this function.
  long unsigned NumBBs;
};
ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
  Function *F = getAnchorScope();
  ReversePostOrderTraversal<Function *> RPOT(F);
  auto NumSingleThreadedBBs = SingleThreadedBBs.size();

  bool AllCallSitesKnown;
  auto PredForCallSite = [&](AbstractCallSite ACS) {
    const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
        *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
        DepClassTy::REQUIRED);
    return ACS.isDirectCall() &&
           ExecutionDomainAA.isExecutedByInitialThreadOnly(
               *ACS.getInstruction());
  };

  if (!A.checkForAllCallSites(PredForCallSite, *this,
                              /* RequiresAllCallSites */ true,
                              AllCallSitesKnown))
    SingleThreadedBBs.erase(&F->getEntryBlock());

  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];

  // Check if the edge into the successor block compares the __kmpc_target_init
  // result with -1. If we are in non-SPMD-mode that signals only the main
  // thread will execute the edge.
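  // Illustrative IR shape the lambda below looks for (names are hypothetical):
  //
  //   %0 = call i32 @__kmpc_target_init(%struct.ident_t* @loc, i1 false, ...)
  //   %is_main = icmp eq i32 %0, -1
  //   br i1 %is_main, label %main.thread, label %workers
  //
  // Only the edge into %main.thread (the true successor) is treated as
  // executed by the initial thread, and only when the IsSPMD argument of the
  // init call is a constant false.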
  auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
    if (!Edge || !Edge->isConditional())
      return false;
    if (Edge->getSuccessor(0) != SuccessorBB)
      return false;

    auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
    if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
      return false;

    ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
    if (!C)
      return false;

    // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
    if (C->isAllOnesValue()) {
      auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
      CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
      if (!CB)
        return false;
      const int InitIsSPMDArgNo = 1;
      auto *IsSPMDModeCI =
          dyn_cast<ConstantInt>(CB->getOperand(InitIsSPMDArgNo));
      return IsSPMDModeCI && IsSPMDModeCI->isZero();
    }

    return false;
  };

  // Merge all the predecessor states into the current basic block. A basic
  // block is executed by a single thread if all of its predecessors are.
  auto MergePredecessorStates = [&](BasicBlock *BB) {
    if (pred_begin(BB) == pred_end(BB))
      return SingleThreadedBBs.contains(BB);

    bool IsInitialThread = true;
    for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
         PredBB != PredEndBB; ++PredBB) {
      if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
                               BB))
        IsInitialThread &= SingleThreadedBBs.contains(*PredBB);
    }

    return IsInitialThread;
  };

  for (auto *BB : RPOT) {
    if (!MergePredecessorStates(BB))
      SingleThreadedBBs.erase(BB);
  }

  return (NumSingleThreadedBBs == SingleThreadedBBs.size())
             ? ChangeStatus::UNCHANGED
             : ChangeStatus::CHANGED;
}
/// Try to replace memory allocation calls called by a single thread with a
/// static buffer of shared memory.
struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
  using Base = StateWrapper<BooleanState, AbstractAttribute>;
  AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}

  /// Create an abstract attribute view for the position \p IRP.
  static AAHeapToShared &createForPosition(const IRPosition &IRP,
                                           Attributor &A);

  /// Returns true if HeapToShared conversion is assumed to be possible.
  virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;

  /// Returns true if HeapToShared conversion is assumed and the CB is a
  /// callsite to a free operation to be removed.
  virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;

  /// See AbstractAttribute::getName().
  const std::string getName() const override { return "AAHeapToShared"; }

  /// See AbstractAttribute::getIdAddr().
  const char *getIdAddr() const override { return &ID; }

  /// This function should return true if the type of the \p AA is
  /// AAHeapToShared.
  static bool classof(const AbstractAttribute *AA) {
    return (AA->getIdAddr() == &ID);
  }

  /// Unique ID (due to the unique address)
  static const char ID;
};
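// Illustrative sketch only (sizes and exact runtime signatures are for
// illustration): the HeapToShared rewrite implemented below turns
//
//   %p = call i8* @__kmpc_alloc_shared(i64 16)
//   ...
//   call void @__kmpc_free_shared(i8* %p)
//
// into accesses of a 16-byte module-level array placed in the shared address
// space, with both runtime calls removed.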
struct AAHeapToSharedFunction : public AAHeapToShared {
  AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
      : AAHeapToShared(IRP, A) {}

  const std::string getAsStr() const override {
    return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
           " malloc calls eligible.";
  }

  /// See AbstractAttribute::trackStatistics().
  void trackStatistics() const override {}

  /// This function finds free calls that will be removed by the
  /// HeapToShared transformation.
  void findPotentialRemovedFreeCalls(Attributor &A) {
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];

    PotentialRemovedFreeCalls.clear();
    // Update free call users of found malloc calls.
    for (CallBase *CB : MallocCalls) {
      SmallVector<CallBase *, 4> FreeCalls;
      for (auto *U : CB->users()) {
        CallBase *C = dyn_cast<CallBase>(U);
        if (C && C->getCalledFunction() == FreeRFI.Declaration)
          FreeCalls.push_back(C);
      }

      if (FreeCalls.size() != 1)
        continue;

      PotentialRemovedFreeCalls.insert(FreeCalls.front());
    }
  }

  void initialize(Attributor &A) override {
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];

    for (User *U : RFI.Declaration->users())
      if (CallBase *CB = dyn_cast<CallBase>(U))
        MallocCalls.insert(CB);

    findPotentialRemovedFreeCalls(A);
  }

  bool isAssumedHeapToShared(CallBase &CB) const override {
    return isValidState() && MallocCalls.count(&CB);
  }

  bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
    return isValidState() && PotentialRemovedFreeCalls.count(&CB);
  }
  ChangeStatus manifest(Attributor &A) override {
    if (MallocCalls.empty())
      return ChangeStatus::UNCHANGED;

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];

    Function *F = getAnchorScope();
    auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
                                            DepClassTy::OPTIONAL);

    ChangeStatus Changed = ChangeStatus::UNCHANGED;
    for (CallBase *CB : MallocCalls) {
      // Skip replacing this if HeapToStack has already claimed it.
      if (HS && HS->isAssumedHeapToStack(*CB))
        continue;

      // Find the unique free call to remove it.
      SmallVector<CallBase *, 4> FreeCalls;
      for (auto *U : CB->users()) {
        CallBase *C = dyn_cast<CallBase>(U);
        if (C && C->getCalledFunction() == FreeCall.Declaration)
          FreeCalls.push_back(C);
      }
      if (FreeCalls.size() != 1)
        continue;

      ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0));

      LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
                        << " with " << AllocSize->getZExtValue()
                        << " bytes of shared memory\n");

      // Create a new shared memory buffer of the same size as the allocation
      // and replace all the uses of the original allocation with it.
      Module *M = CB->getModule();
      Type *Int8Ty = Type::getInt8Ty(M->getContext());
      Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
      auto *SharedMem = new GlobalVariable(
          *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
          UndefValue::get(Int8ArrTy), CB->getName(), nullptr,
          GlobalValue::NotThreadLocal,
          static_cast<unsigned>(AddressSpace::Shared));
      auto *NewBuffer =
          ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Replaced globalized variable with "
                  << ore::NV("SharedMemory", AllocSize->getZExtValue())
                  << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
                  << "of shared memory.";
      };
      A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);

      SharedMem->setAlignment(MaybeAlign(32));

      A.changeValueAfterManifest(*CB, *NewBuffer);
      A.deleteAfterManifest(*CB);
      A.deleteAfterManifest(*FreeCalls.front());

      NumBytesMovedToSharedMemory += AllocSize->getZExtValue();
      Changed = ChangeStatus::CHANGED;
    }

    return Changed;
  }

  ChangeStatus updateImpl(Attributor &A) override {
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
    Function *F = getAnchorScope();

    auto NumMallocCalls = MallocCalls.size();

    // Only consider malloc calls executed by a single thread with a constant.
    for (User *U : RFI.Declaration->users()) {
      const auto &ED = A.getAAFor<AAExecutionDomain>(
          *this, IRPosition::function(*F), DepClassTy::REQUIRED);
      if (CallBase *CB = dyn_cast<CallBase>(U))
        if (!dyn_cast<ConstantInt>(CB->getArgOperand(0)) ||
            !ED.isExecutedByInitialThreadOnly(*CB))
          MallocCalls.erase(CB);
    }

    findPotentialRemovedFreeCalls(A);

    if (NumMallocCalls != MallocCalls.size())
      return ChangeStatus::CHANGED;

    return ChangeStatus::UNCHANGED;
  }

  /// Collection of all malloc calls in a function.
  SmallPtrSet<CallBase *, 4> MallocCalls;
  /// Collection of potentially removed free calls in a function.
  SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
};
struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
  using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
  AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}

  /// Statistics are tracked as part of manifest for now.
  void trackStatistics() const override {}

  /// See AbstractAttribute::getAsStr()
  const std::string getAsStr() const override {
    if (!isValidState())
      return "<invalid>";
    return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
                                                            : "generic") +
           std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
                                                               : "") +
           std::string(" #PRs: ") +
           std::to_string(ReachedKnownParallelRegions.size()) +
           ", #Unknown PRs: " +
           std::to_string(ReachedUnknownParallelRegions.size());
  }

  /// Create an abstract attribute view for the position \p IRP.
  static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);

  /// See AbstractAttribute::getName()
  const std::string getName() const override { return "AAKernelInfo"; }

  /// See AbstractAttribute::getIdAddr()
  const char *getIdAddr() const override { return &ID; }

  /// This function should return true if the type of the \p AA is AAKernelInfo
  static bool classof(const AbstractAttribute *AA) {
    return (AA->getIdAddr() == &ID);
  }

  /// Unique ID (due to the unique address)
  static const char ID;
};
/// The function kernel info abstract attribute, basically, what can we say
/// about a function with regards to the KernelInfoState.
struct AAKernelInfoFunction : AAKernelInfo {
  AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
      : AAKernelInfo(IRP, A) {}

  SmallPtrSet<Instruction *, 4> GuardedInstructions;

  SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
    return GuardedInstructions;
  }

  /// See AbstractAttribute::initialize(...).
  void initialize(Attributor &A) override {
    // This is a high-level transform that might change the constant arguments
    // of the init and deinit calls. We need to tell the Attributor about this
    // to avoid other parts using the current constant value for simplification.
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

    Function *Fn = getAnchorScope();
    if (!OMPInfoCache.Kernels.count(Fn))
      return;

    // Add itself to the reaching kernel and set IsKernelEntry.
    ReachingKernelEntries.insert(Fn);
    IsKernelEntry = true;

    OMPInformationCache::RuntimeFunctionInfo &InitRFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
    OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];

    // For kernels we perform more initialization work, first we find the init
    // and deinit calls.
    auto StoreCallBase = [](Use &U,
                            OMPInformationCache::RuntimeFunctionInfo &RFI,
                            CallBase *&Storage) {
      CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
      assert(CB &&
             "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
      assert(!Storage &&
             "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
      Storage = CB;
      return false;
    };
    InitRFI.foreachUse(
        [&](Use &U, Function &) {
          StoreCallBase(U, InitRFI, KernelInitCB);
          return false;
        },
        Fn);
    DeinitRFI.foreachUse(
        [&](Use &U, Function &) {
          StoreCallBase(U, DeinitRFI, KernelDeinitCB);
          return false;
        },
        Fn);

    // Ignore kernels without initializers such as global constructors.
    if (!KernelInitCB || !KernelDeinitCB) {
      indicateOptimisticFixpoint();
      return;
    }

    // For kernels we might need to initialize/finalize the IsSPMD state and
    // we need to register a simplification callback so that the Attributor
    // knows the constant arguments to __kmpc_target_init and
    // __kmpc_target_deinit might actually change.

    Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
        [&](const IRPosition &IRP, const AbstractAttribute *AA,
            bool &UsedAssumedInformation) -> Optional<Value *> {
      // IRP represents the "use generic state machine" argument of an
      // __kmpc_target_init call. We will answer this one with the internal
      // state. As long as we are not in an invalid state, we will create a
      // custom state machine so the value should be a `i1 false`. If we are
      // in an invalid state, we won't change the value that is in the IR.
      if (!isValidState())
        return nullptr;
      // If we have disabled state machine rewrites, don't make a custom one.
      if (DisableOpenMPOptStateMachineRewrite)
        return nullptr;
      if (AA)
        A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
      UsedAssumedInformation = !isAtFixpoint();
      auto *FalseVal =
          ConstantInt::getBool(IRP.getAnchorValue().getContext(), 0);
      return FalseVal;
    };

    Attributor::SimplifictionCallbackTy IsSPMDModeSimplifyCB =
        [&](const IRPosition &IRP, const AbstractAttribute *AA,
            bool &UsedAssumedInformation) -> Optional<Value *> {
      // IRP represents the "SPMDCompatibilityTracker" argument of an
      // __kmpc_target_init or
      // __kmpc_target_deinit call. We will answer this one with the internal
      // state.
      if (!SPMDCompatibilityTracker.isValidState())
        return nullptr;
      if (!SPMDCompatibilityTracker.isAtFixpoint()) {
        if (AA)
          A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
        UsedAssumedInformation = true;
      } else {
        UsedAssumedInformation = false;
      }
      auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
                                       SPMDCompatibilityTracker.isAssumed());
      return Val;
    };

    Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
        [&](const IRPosition &IRP, const AbstractAttribute *AA,
            bool &UsedAssumedInformation) -> Optional<Value *> {
      // IRP represents the "RequiresFullRuntime" argument of an
      // __kmpc_target_init or __kmpc_target_deinit call. We will answer this
      // one with the internal state of the SPMDCompatibilityTracker, so if
      // generic then true, if SPMD then false.
      if (!SPMDCompatibilityTracker.isValidState())
        return nullptr;
      if (!SPMDCompatibilityTracker.isAtFixpoint()) {
        if (AA)
          A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
        UsedAssumedInformation = true;
      } else {
        UsedAssumedInformation = false;
      }
      auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
                                       !SPMDCompatibilityTracker.isAssumed());
      return Val;
    };

    constexpr const int InitIsSPMDArgNo = 1;
    constexpr const int DeinitIsSPMDArgNo = 1;
    constexpr const int InitUseStateMachineArgNo = 2;
    constexpr const int InitRequiresFullRuntimeArgNo = 3;
    constexpr const int DeinitRequiresFullRuntimeArgNo = 2;
    A.registerSimplificationCallback(
        IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
        StateMachineSimplifyCB);
    A.registerSimplificationCallback(
        IRPosition::callsite_argument(*KernelInitCB, InitIsSPMDArgNo),
        IsSPMDModeSimplifyCB);
    A.registerSimplificationCallback(
        IRPosition::callsite_argument(*KernelDeinitCB, DeinitIsSPMDArgNo),
        IsSPMDModeSimplifyCB);
    A.registerSimplificationCallback(
        IRPosition::callsite_argument(*KernelInitCB,
                                      InitRequiresFullRuntimeArgNo),
        IsGenericModeSimplifyCB);
    A.registerSimplificationCallback(
        IRPosition::callsite_argument(*KernelDeinitCB,
                                      DeinitRequiresFullRuntimeArgNo),
        IsGenericModeSimplifyCB);

    // Check if we know we are in SPMD-mode already.
    ConstantInt *IsSPMDArg =
        dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
    if (IsSPMDArg && !IsSPMDArg->isZero())
      SPMDCompatibilityTracker.indicateOptimisticFixpoint();
    // This is a generic region but SPMDization is disabled so stop tracking.
    else if (DisableOpenMPOptSPMDization)
      SPMDCompatibilityTracker.indicatePessimisticFixpoint();
  }
  /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
  /// finished now.
  ChangeStatus manifest(Attributor &A) override {
    // If we are not looking at a kernel with __kmpc_target_init and
    // __kmpc_target_deinit call we cannot actually manifest the information.
    if (!KernelInitCB || !KernelDeinitCB)
      return ChangeStatus::UNCHANGED;

    // Known SPMD-mode kernels need no manifest changes.
    if (SPMDCompatibilityTracker.isKnown())
      return ChangeStatus::UNCHANGED;

    // If we can we change the execution mode to SPMD-mode otherwise we build a
    // custom state machine.
    if (!changeToSPMDMode(A))
      buildCustomStateMachine(A);

    return ChangeStatus::CHANGED;
  }

  bool changeToSPMDMode(Attributor &A) {
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

    if (!SPMDCompatibilityTracker.isAssumed()) {
      for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
        if (!NonCompatibleI)
          continue;

        // Skip diagnostics on calls to known OpenMP runtime functions for now.
        if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
          if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
            continue;

        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
          ORA << "Value has potential side effects preventing SPMD-mode "
                 "execution";
          if (isa<CallBase>(NonCompatibleI)) {
            ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
                   "the called function to override";
          }
          return ORA << ".";
        };
        A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
                                                 Remark);

        LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
                          << *NonCompatibleI << "\n");
      }

      return false;
    }
    auto CreateGuardedRegion = [&](Instruction *RegionStartI,
                                   Instruction *RegionEndI) {
      LoopInfo *LI = nullptr;
      DominatorTree *DT = nullptr;
      MemorySSAUpdater *MSU = nullptr;
      using InsertPointTy = OpenMPIRBuilder::InsertPointTy;

      BasicBlock *ParentBB = RegionStartI->getParent();
      Function *Fn = ParentBB->getParent();
      Module &M = *Fn->getParent();

      // Create all the blocks and logic.
      // ParentBB:
      //    goto RegionCheckTidBB
      // RegionCheckTidBB:
      //    Tid = __kmpc_hardware_thread_id()
      //    if (Tid != 0)
      //        goto RegionBarrierBB
      // RegionStartBB:
      //    <execute instructions guarded>
      //    goto RegionEndBB
      // RegionEndBB:
      //    <store escaping values to shared mem>
      //    goto RegionBarrierBB
      // RegionBarrierBB:
      //    __kmpc_simple_barrier_spmd()
      //    // second barrier is omitted if lacking escaping values.
      //    <load escaping values from shared mem>
      //    __kmpc_simple_barrier_spmd()
      //    goto RegionExitBB
      // RegionExitBB:
      //    <execute rest of instructions>

      BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
                                           DT, LI, MSU, "region.guarded.end");
      BasicBlock *RegionBarrierBB =
          SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
                     MSU, "region.barrier");
      BasicBlock *RegionExitBB =
          SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
                     DT, LI, MSU, "region.exit");
      BasicBlock *RegionStartBB =
          SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");

      assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
             "Expected a different CFG");

      BasicBlock *RegionCheckTidBB = SplitBlock(
          ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");

      // Register basic blocks with the Attributor.
      A.registerManifestAddedBasicBlock(*RegionEndBB);
      A.registerManifestAddedBasicBlock(*RegionBarrierBB);
      A.registerManifestAddedBasicBlock(*RegionExitBB);
      A.registerManifestAddedBasicBlock(*RegionStartBB);
      A.registerManifestAddedBasicBlock(*RegionCheckTidBB);

      bool HasBroadcastValues = false;
      // Find escaping outputs from the guarded region to outside users and
      // broadcast their values to them.
      for (Instruction &I : *RegionStartBB) {
        SmallPtrSet<Instruction *, 4> OutsideUsers;
        for (User *Usr : I.users()) {
          Instruction &UsrI = *cast<Instruction>(Usr);
          if (UsrI.getParent() != RegionStartBB)
            OutsideUsers.insert(&UsrI);
        }

        if (OutsideUsers.empty())
          continue;

        HasBroadcastValues = true;

        // Emit a global variable in shared memory to store the broadcasted
        // value.
        auto *SharedMem = new GlobalVariable(
            M, I.getType(), /* IsConstant */ false,
            GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
            I.getName() + ".guarded.output.alloc", nullptr,
            GlobalValue::NotThreadLocal,
            static_cast<unsigned>(AddressSpace::Shared));

        // Emit a store instruction to update the value.
        new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());

        LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
                                       I.getName() + ".guarded.output.load",
                                       RegionBarrierBB->getTerminator());

        // Emit a load instruction and replace uses of the output value.
        for (Instruction *UsrI : OutsideUsers) {
          assert(UsrI->getParent() == RegionExitBB &&
                 "Expected escaping users in exit region");
          UsrI->replaceUsesOfWith(&I, LoadI);
        }
      }

      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

      // Go to tid check BB in ParentBB.
      const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
      ParentBB->getTerminator()->eraseFromParent();
      OpenMPIRBuilder::LocationDescription Loc(
          InsertPointTy(ParentBB, ParentBB->end()), DL);
      OMPInfoCache.OMPBuilder.updateToLocation(Loc);
      auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc);
      Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr);
      BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);

      // Add check for Tid in RegionCheckTidBB
      RegionCheckTidBB->getTerminator()->eraseFromParent();
      OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
          InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
      OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
      FunctionCallee HardwareTidFn =
          OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
              M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
      Value *Tid =
          OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
      Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
      OMPInfoCache.OMPBuilder.Builder
          .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
          ->setDebugLoc(DL);

      // First barrier for synchronization, ensures main thread has updated
      // values.
      FunctionCallee BarrierFn =
          OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
              M, OMPRTL___kmpc_barrier_simple_spmd);
      OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
          RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
      OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid})
          ->setDebugLoc(DL);

      // Second barrier ensures workers have read broadcast values.
      if (HasBroadcastValues)
        CallInst::Create(BarrierFn, {Ident, Tid}, "",
                         RegionBarrierBB->getTerminator())
            ->setDebugLoc(DL);
    };
    SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;

    for (Instruction *GuardedI : SPMDCompatibilityTracker) {
      BasicBlock *BB = GuardedI->getParent();
      auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
          IRPosition::function(*GuardedI->getFunction()), nullptr,
          DepClassTy::NONE);
      assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
      auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
      // Continue if instruction is already guarded.
      if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
        continue;

      Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
      for (Instruction &I : *BB) {
        // If instruction I needs to be guarded update the guarded region
        // bounds.
        if (SPMDCompatibilityTracker.contains(&I)) {
          CalleeAAFunction.getGuardedInstructions().insert(&I);
          if (GuardedRegionStart)
            GuardedRegionEnd = &I;
          else
            GuardedRegionStart = GuardedRegionEnd = &I;

          continue;
        }

        // Instruction I does not need guarding, store
        // any region found and reset bounds.
        if (GuardedRegionStart) {
          GuardedRegions.push_back(
              std::make_pair(GuardedRegionStart, GuardedRegionEnd));
          GuardedRegionStart = nullptr;
          GuardedRegionEnd = nullptr;
        }
      }
    }

    for (auto &GR : GuardedRegions)
      CreateGuardedRegion(GR.first, GR.second);

    // Adjust the global exec mode flag that tells the runtime what mode this
    // kernel is executed in.
    Function *Kernel = getAnchorScope();
    GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
        (Kernel->getName() + "_exec_mode").str());
    assert(ExecMode && "Kernel without exec mode?");
    assert(ExecMode->getInitializer() &&
           ExecMode->getInitializer()->isOneValue() &&
           "Initially non-SPMD kernel has SPMD exec mode!");

    // Set the global exec mode flag to indicate SPMD-Generic mode.
    constexpr int SPMDGeneric = 2;
    if (!ExecMode->getInitializer()->isZeroValue())
      ExecMode->setInitializer(
          ConstantInt::get(ExecMode->getInitializer()->getType(), SPMDGeneric));
    // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
    const int InitIsSPMDArgNo = 1;
    const int DeinitIsSPMDArgNo = 1;
    const int InitUseStateMachineArgNo = 2;
    const int InitRequiresFullRuntimeArgNo = 3;
    const int DeinitRequiresFullRuntimeArgNo = 2;

    auto &Ctx = getAnchorValue().getContext();
    A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo),
                             *ConstantInt::getBool(Ctx, 1));
    A.changeUseAfterManifest(
        KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
        *ConstantInt::getBool(Ctx, 0));
    A.changeUseAfterManifest(
        KernelDeinitCB->getArgOperandUse(DeinitIsSPMDArgNo),
        *ConstantInt::getBool(Ctx, 1));
    A.changeUseAfterManifest(
        KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo),
        *ConstantInt::getBool(Ctx, 0));
    A.changeUseAfterManifest(
        KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo),
        *ConstantInt::getBool(Ctx, 0));

    ++NumOpenMPTargetRegionKernelsSPMD;

    auto Remark = [&](OptimizationRemark OR) {
      return OR << "Transformed generic-mode kernel to SPMD-mode.";
    };
    A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
    return true;
  }
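  // Illustrative sketch only: after the argument rewrites above, a kernel
  // that started out as
  //
  //   call i32 @__kmpc_target_init(..., i1 false /*IsSPMD*/,
  //                                i1 true /*UseStateMachine*/, ...)
  //
  // conceptually becomes
  //
  //   call i32 @__kmpc_target_init(..., i1 true /*IsSPMD*/,
  //                                i1 false /*UseStateMachine*/, ...)
  //
  // with the RequiresFullRuntime arguments set to false and the kernel's
  // *_exec_mode global switched to the SPMD-Generic value.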
  ChangeStatus buildCustomStateMachine(Attributor &A) {
    // If we have disabled state machine rewrites, don't make a custom one.
    if (DisableOpenMPOptStateMachineRewrite)
      return indicatePessimisticFixpoint();

    assert(ReachedKnownParallelRegions.isValidState() &&
           "Custom state machine with invalid parallel region states?");

    const int InitIsSPMDArgNo = 1;
    const int InitUseStateMachineArgNo = 2;

    // Check if the current configuration is non-SPMD and generic state machine.
    // If we already have SPMD mode or a custom state machine we do not need to
    // go any further. If it is anything but a constant something is weird and
    // we give up.
    ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
        KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
    ConstantInt *IsSPMD =
        dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));

    // If we are stuck with generic mode, try to create a custom device (=GPU)
    // state machine which is specialized for the parallel regions that are
    // reachable by the kernel.
    if (!UseStateMachine || UseStateMachine->isZero() || !IsSPMD ||
        !IsSPMD->isZero())
      return ChangeStatus::UNCHANGED;

    // If not SPMD mode, indicate we use a custom state machine now.
    auto &Ctx = getAnchorValue().getContext();
    auto *FalseVal = ConstantInt::getBool(Ctx, 0);
    A.changeUseAfterManifest(
        KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);

    // If we don't actually need a state machine we are done here. This can
    // happen if there simply are no parallel regions. In the resulting kernel
    // all worker threads will simply exit right away, leaving the main thread
    // to do the work alone.
    if (ReachedKnownParallelRegions.empty() &&
        ReachedUnknownParallelRegions.empty()) {
      ++NumOpenMPTargetRegionKernelsWithoutStateMachine;

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Removing unused state machine from generic-mode kernel.";
      };
      A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);

      return ChangeStatus::CHANGED;
    }

    // Keep track in the statistics of our new shiny custom state machine.
    if (ReachedUnknownParallelRegions.empty()) {
      ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Rewriting generic-mode kernel with a customized state "
                     "machine.";
      };
      A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
    } else {
      ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;

      auto Remark = [&](OptimizationRemarkAnalysis OR) {
        return OR << "Generic-mode kernel is executed with a customized state "
                     "machine that requires a fallback.";
      };
      A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);

      // Tell the user why we ended up with a fallback.
      for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
        if (!UnknownParallelRegionCB)
          continue;
        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
          return ORA << "Call may contain unknown parallel regions. Use "
                     << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
                        "override.";
        };
        A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
                                                 "OMP133", Remark);
      }
    }
3350 // Create all the blocks:
3352 // InitCB = __kmpc_target_init(...)
3353 // bool IsWorker = InitCB >= 0;
3355 // SMBeginBB: __kmpc_barrier_simple_spmd(...);
3357 // bool Active = __kmpc_kernel_parallel(&WorkFn);
3358 // if (!WorkFn) return;
3359 // SMIsActiveCheckBB: if (Active) {
3360 // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
3362 // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
3365 // SMIfCascadeCurrentBB: else
3366 // ((WorkFnTy*)WorkFn)(...);
3367 // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
3369 // SMDoneBB: __kmpc_barrier_simple_spmd(...);
3372 // UserCodeEntryBB: // user code
3373 // __kmpc_target_deinit(...)
3375 Function
*Kernel
= getAssociatedFunction();
3376 assert(Kernel
&& "Expected an associated function!");
3378 BasicBlock
*InitBB
= KernelInitCB
->getParent();
3379 BasicBlock
*UserCodeEntryBB
= InitBB
->splitBasicBlock(
3380 KernelInitCB
->getNextNode(), "thread.user_code.check");
3381 BasicBlock
*StateMachineBeginBB
= BasicBlock::Create(
3382 Ctx
, "worker_state_machine.begin", Kernel
, UserCodeEntryBB
);
3383 BasicBlock
*StateMachineFinishedBB
= BasicBlock::Create(
3384 Ctx
, "worker_state_machine.finished", Kernel
, UserCodeEntryBB
);
3385 BasicBlock
*StateMachineIsActiveCheckBB
= BasicBlock::Create(
3386 Ctx
, "worker_state_machine.is_active.check", Kernel
, UserCodeEntryBB
);
3387 BasicBlock
*StateMachineIfCascadeCurrentBB
=
3388 BasicBlock::Create(Ctx
, "worker_state_machine.parallel_region.check",
3389 Kernel
, UserCodeEntryBB
);
3390 BasicBlock
*StateMachineEndParallelBB
=
3391 BasicBlock::Create(Ctx
, "worker_state_machine.parallel_region.end",
3392 Kernel
, UserCodeEntryBB
);
3393 BasicBlock
*StateMachineDoneBarrierBB
= BasicBlock::Create(
3394 Ctx
, "worker_state_machine.done.barrier", Kernel
, UserCodeEntryBB
);
3395 A
.registerManifestAddedBasicBlock(*InitBB
);
3396 A
.registerManifestAddedBasicBlock(*UserCodeEntryBB
);
3397 A
.registerManifestAddedBasicBlock(*StateMachineBeginBB
);
3398 A
.registerManifestAddedBasicBlock(*StateMachineFinishedBB
);
3399 A
.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB
);
3400 A
.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB
);
3401 A
.registerManifestAddedBasicBlock(*StateMachineEndParallelBB
);
3402 A
.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB
);
    const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
    ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);

    InitBB->getTerminator()->eraseFromParent();
    Instruction *IsWorker =
        ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
                         ConstantInt::get(KernelInitCB->getType(), -1),
                         "thread.is_worker", InitBB);
    IsWorker->setDebugLoc(DLoc);
    BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB);

    // Create local storage for the work function pointer.
    Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
    AllocaInst *WorkFnAI = new AllocaInst(VoidPtrTy, 0, "worker.work_fn.addr",
                                          &Kernel->getEntryBlock().front());
    WorkFnAI->setDebugLoc(DLoc);
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    OMPInfoCache.OMPBuilder.updateToLocation(
        OpenMPIRBuilder::LocationDescription(
            IRBuilder<>::InsertPoint(StateMachineBeginBB,
                                     StateMachineBeginBB->end()),
            DLoc));

    Value *Ident = KernelInitCB->getArgOperand(0);
    Value *GTid = KernelInitCB;

    Module &M = *Kernel->getParent();
    FunctionCallee BarrierFn =
        OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
            M, OMPRTL___kmpc_barrier_simple_spmd);
    CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB)
        ->setDebugLoc(DLoc);

    FunctionCallee KernelParallelFn =
        OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
            M, OMPRTL___kmpc_kernel_parallel);
    Instruction *IsActiveWorker = CallInst::Create(
        KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
    IsActiveWorker->setDebugLoc(DLoc);
    Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
                                       StateMachineBeginBB);
    WorkFn->setDebugLoc(DLoc);

    FunctionType *ParallelRegionFnTy = FunctionType::get(
        Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
        false);
    Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
        WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
        StateMachineBeginBB);

    Instruction *IsDone =
        ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
                         Constant::getNullValue(VoidPtrTy), "worker.is_done",
                         StateMachineBeginBB);
    IsDone->setDebugLoc(DLoc);
    BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
                       IsDone, StateMachineBeginBB)
        ->setDebugLoc(DLoc);

    BranchInst::Create(StateMachineIfCascadeCurrentBB,
                       StateMachineDoneBarrierBB, IsActiveWorker,
                       StateMachineIsActiveCheckBB)
        ->setDebugLoc(DLoc);
    Value *ZeroArg =
        Constant::getNullValue(ParallelRegionFnTy->getParamType(0));

    // Now that we have most of the CFG skeleton it is time for the if-cascade
    // that checks the function pointer we got from the runtime against the
    // parallel regions we expect, if there are any.
    for (int i = 0, e = ReachedKnownParallelRegions.size(); i < e; ++i) {
      auto *ParallelRegion = ReachedKnownParallelRegions[i];
      BasicBlock *PRExecuteBB = BasicBlock::Create(
          Ctx, "worker_state_machine.parallel_region.execute", Kernel,
          StateMachineEndParallelBB);
      CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
          ->setDebugLoc(DLoc);
      BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
          ->setDebugLoc(DLoc);

      BasicBlock *PRNextBB =
          BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
                             Kernel, StateMachineEndParallelBB);

      // Check if we need to compare the pointer at all or if we can just
      // call the parallel region function.
      Value *IsPR;
      if (i + 1 < e || !ReachedUnknownParallelRegions.empty()) {
        Instruction *CmpI = ICmpInst::Create(
            ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
            "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
        CmpI->setDebugLoc(DLoc);
        IsPR = CmpI;
      } else {
        IsPR = ConstantInt::getTrue(Ctx);
      }

      BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
                         StateMachineIfCascadeCurrentBB)
          ->setDebugLoc(DLoc);
      StateMachineIfCascadeCurrentBB = PRNextBB;
    }
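    // For two known parallel regions the cascade built above therefore looks
    // roughly like (illustrative):
    //   if      (WorkFnCast == ParFn0) ParFn0(ZeroArg, GTid);
    //   else if (WorkFnCast == ParFn1) ParFn1(ZeroArg, GTid);
    //   else                           <fallback, handled below>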
    // At the end of the if-cascade we place the indirect function pointer call
    // in case we might need it, that is if there can be parallel regions we
    // have not handled in the if-cascade above.
    if (!ReachedUnknownParallelRegions.empty()) {
      StateMachineIfCascadeCurrentBB->setName(
          "worker_state_machine.parallel_region.fallback.execute");
      CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
                       StateMachineIfCascadeCurrentBB)
          ->setDebugLoc(DLoc);
    }
    BranchInst::Create(StateMachineEndParallelBB,
                       StateMachineIfCascadeCurrentBB)
        ->setDebugLoc(DLoc);

    CallInst::Create(OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
                         M, OMPRTL___kmpc_kernel_end_parallel),
                     {}, "", StateMachineEndParallelBB)
        ->setDebugLoc(DLoc);
    BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
        ->setDebugLoc(DLoc);

    CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
        ->setDebugLoc(DLoc);
    BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
        ->setDebugLoc(DLoc);

    return ChangeStatus::CHANGED;
  }
  /// Fixpoint iteration update function. Will be called every time a
  /// dependence changed its state (and in the beginning).
  ChangeStatus updateImpl(Attributor &A) override {
    KernelInfoState StateBefore = getState();

    // Callback to check a read/write instruction.
    auto CheckRWInst = [&](Instruction &I) {
      // We handle calls later.
      if (isa<CallBase>(I))
        return true;
      // We only care about write effects.
      if (!I.mayWriteToMemory())
        return true;
      if (auto *SI = dyn_cast<StoreInst>(&I)) {
        SmallVector<const Value *> Objects;
        getUnderlyingObjects(SI->getPointerOperand(), Objects);
        if (llvm::all_of(Objects,
                         [](const Value *Obj) { return isa<AllocaInst>(Obj); }))
          return true;
        // Check for AAHeapToStack moved objects which must not be guarded.
        auto &HS = A.getAAFor<AAHeapToStack>(
            *this, IRPosition::function(*I.getFunction()),
            DepClassTy::REQUIRED);
        if (llvm::all_of(Objects, [&HS](const Value *Obj) {
              auto *CB = dyn_cast<CallBase>(Obj);
              if (!CB)
                return false;
              return HS.isAssumedHeapToStack(*CB);
            }))
          return true;
      }

      // Insert instruction that needs guarding.
      SPMDCompatibilityTracker.insert(&I);
      return true;
    };
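    // Note: everything the callback above records in SPMDCompatibilityTracker
    // is a side effect that would have to be guarded (executed by a single
    // thread) if this generic-mode kernel were later rewritten to SPMD mode.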
    bool UsedAssumedInformationInCheckRWInst = false;
    if (!SPMDCompatibilityTracker.isAtFixpoint())
      if (!A.checkForAllReadWriteInstructions(
              CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
        SPMDCompatibilityTracker.indicatePessimisticFixpoint();

    if (!IsKernelEntry) {
      updateReachingKernelEntries(A);
      updateParallelLevels(A);

      if (!ParallelLevels.isValidState())
        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
    }

    // Callback to check a call instruction.
    bool AllSPMDStatesWereFixed = true;
    auto CheckCallInst = [&](Instruction &I) {
      auto &CB = cast<CallBase>(I);
      auto &CBAA = A.getAAFor<AAKernelInfo>(
          *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
      getState() ^= CBAA.getState();
      AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
      return true;
    };

    bool UsedAssumedInformationInCheckCallInst = false;
    if (!A.checkForAllCallLikeInstructions(
            CheckCallInst, *this, UsedAssumedInformationInCheckCallInst))
      return indicatePessimisticFixpoint();

    // If we haven't used any assumed information for the SPMD state we can fix
    // it.
    if (!UsedAssumedInformationInCheckRWInst &&
        !UsedAssumedInformationInCheckCallInst && AllSPMDStatesWereFixed)
      SPMDCompatibilityTracker.indicateOptimisticFixpoint();

    return StateBefore == getState() ? ChangeStatus::UNCHANGED
                                     : ChangeStatus::CHANGED;
  }
  /// Update info regarding reaching kernels.
  void updateReachingKernelEntries(Attributor &A) {
    auto PredCallSite = [&](AbstractCallSite ACS) {
      Function *Caller = ACS.getInstruction()->getFunction();

      assert(Caller && "Caller is nullptr");

      auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
          IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
      if (CAA.ReachingKernelEntries.isValidState()) {
        ReachingKernelEntries ^= CAA.ReachingKernelEntries;
        return true;
      }

      // We lost track of the caller of the associated function, any kernel
      // could reach now.
      ReachingKernelEntries.indicatePessimisticFixpoint();

      return true;
    };

    bool AllCallSitesKnown;
    if (!A.checkForAllCallSites(PredCallSite, *this,
                                true /* RequireAllCallSites */,
                                AllCallSitesKnown))
      ReachingKernelEntries.indicatePessimisticFixpoint();
  }
  /// Update info regarding parallel levels.
  void updateParallelLevels(Attributor &A) {
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];

    auto PredCallSite = [&](AbstractCallSite ACS) {
      Function *Caller = ACS.getInstruction()->getFunction();

      assert(Caller && "Caller is nullptr");

      auto &CAA =
          A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
      if (CAA.ParallelLevels.isValidState()) {
        // Any function that is called by `__kmpc_parallel_51` will not be
        // folded as the parallel level in the function is updated. Getting it
        // right would make the analysis depend on the runtime implementation,
        // and any future change to that implementation could silently break
        // the analysis. As a consequence, we are just conservative here.
        if (Caller == Parallel51RFI.Declaration) {
          ParallelLevels.indicatePessimisticFixpoint();
          return true;
        }

        ParallelLevels ^= CAA.ParallelLevels;

        return true;
      }

      // We lost track of the caller of the associated function, any kernel
      // could reach now.
      ParallelLevels.indicatePessimisticFixpoint();

      return true;
    };

    bool AllCallSitesKnown = true;
    if (!A.checkForAllCallSites(PredCallSite, *this,
                                true /* RequireAllCallSites */,
                                AllCallSitesKnown))
      ParallelLevels.indicatePessimisticFixpoint();
  }
/// The call site kernel info abstract attribute, basically, what can we say
/// about a call site with regard to the KernelInfoState. For now this simply
/// forwards the information from the callee.
struct AAKernelInfoCallSite : AAKernelInfo {
  AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
      : AAKernelInfo(IRP, A) {}

  /// See AbstractAttribute::initialize(...).
  void initialize(Attributor &A) override {
    AAKernelInfo::initialize(A);

    CallBase &CB = cast<CallBase>(getAssociatedValue());
    Function *Callee = getAssociatedFunction();

    // Helper to lookup an assumption string.
    auto HasAssumption = [](Function *Fn, StringRef AssumptionStr) {
      return Fn && hasAssumption(*Fn, AssumptionStr);
    };

    // Check for SPMD-mode assumptions.
    if (HasAssumption(Callee, "ompx_spmd_amenable"))
      SPMDCompatibilityTracker.indicateOptimisticFixpoint();

    // First weed out calls we do not care about, that is readonly/readnone
    // calls, intrinsics, and "no_openmp" calls. None of these can reach a
    // parallel region or anything else we are looking for.
    if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
      indicateOptimisticFixpoint();
      return;
    }

    // Next we check if we know the callee. If it is a known OpenMP function
    // we will handle it explicitly in the switch below. If it is not, we
    // will use an AAKernelInfo object on the callee to gather information and
    // merge that into the current state. The latter happens in the updateImpl.
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
    if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
      // Unknown caller or declarations are not analyzable, we give up.
      if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {

        // Unknown callees might contain parallel regions, except if they have
        // an appropriate assumption attached.
        if (!(HasAssumption(Callee, "omp_no_openmp") ||
              HasAssumption(Callee, "omp_no_parallelism")))
          ReachedUnknownParallelRegions.insert(&CB);

        // If SPMDCompatibilityTracker is not fixed, we need to give up on the
        // idea we can run something unknown in SPMD-mode.
        if (!SPMDCompatibilityTracker.isAtFixpoint()) {
          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
          SPMDCompatibilityTracker.insert(&CB);
        }

        // We have updated the state for this unknown call properly; there
        // won't be any change, so we indicate a fixpoint.
        indicateOptimisticFixpoint();
      }
      // If the callee is known and can be used in IPO, we will update the
      // state based on the callee state in updateImpl.
      return;
    }
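    // At this point the callee is a known OpenMP runtime function; classify it
    // in the switch below. For example, __kmpc_parallel_51 carries the
    // outlined parallel region wrapper as its 7th argument
    // (WrapperFunctionArgNo == 6), which is how known parallel regions are
    // discovered.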
    const unsigned int WrapperFunctionArgNo = 6;
    RuntimeFunction RF = It->getSecond();
    switch (RF) {
    // All the functions we know are compatible with SPMD mode.
    case OMPRTL___kmpc_is_spmd_exec_mode:
    case OMPRTL___kmpc_for_static_fini:
    case OMPRTL___kmpc_global_thread_num:
    case OMPRTL___kmpc_get_hardware_num_threads_in_block:
    case OMPRTL___kmpc_get_hardware_num_blocks:
    case OMPRTL___kmpc_single:
    case OMPRTL___kmpc_end_single:
    case OMPRTL___kmpc_master:
    case OMPRTL___kmpc_end_master:
    case OMPRTL___kmpc_barrier:
      break;
    case OMPRTL___kmpc_for_static_init_4:
    case OMPRTL___kmpc_for_static_init_4u:
    case OMPRTL___kmpc_for_static_init_8:
    case OMPRTL___kmpc_for_static_init_8u: {
      // Check the schedule and allow static schedule in SPMD mode.
      unsigned ScheduleArgOpNo = 2;
      auto *ScheduleTypeCI =
          dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
      unsigned ScheduleTypeVal =
          ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
      switch (OMPScheduleType(ScheduleTypeVal)) {
      case OMPScheduleType::Static:
      case OMPScheduleType::StaticChunked:
      case OMPScheduleType::Distribute:
      case OMPScheduleType::DistributeChunked:
        break;
      default:
        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        SPMDCompatibilityTracker.insert(&CB);
        break;
      }
      break;
    }
    case OMPRTL___kmpc_target_init:
      KernelInitCB = &CB;
      break;
    case OMPRTL___kmpc_target_deinit:
      KernelDeinitCB = &CB;
      break;
    case OMPRTL___kmpc_parallel_51:
      if (auto *ParallelRegion = dyn_cast<Function>(
              CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
        ReachedKnownParallelRegions.insert(ParallelRegion);
        break;
      }
      // The condition above should usually get the parallel region function
      // pointer and record it. In the off chance it doesn't we assume the
      // worst.
      ReachedUnknownParallelRegions.insert(&CB);
      break;
    case OMPRTL___kmpc_omp_task:
      // We do not look into tasks right now, just give up.
      SPMDCompatibilityTracker.insert(&CB);
      ReachedUnknownParallelRegions.insert(&CB);
      indicatePessimisticFixpoint();
      break;
    case OMPRTL___kmpc_alloc_shared:
    case OMPRTL___kmpc_free_shared:
      // Return without setting a fixpoint, to be resolved in updateImpl.
      return;
    default:
      // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
      // generally. However, they do not hide parallel regions.
      SPMDCompatibilityTracker.insert(&CB);
      indicatePessimisticFixpoint();
      break;
    }
    // All other OpenMP runtime calls will not reach parallel regions so they
    // can be safely ignored for now. Since it is a known OpenMP runtime call,
    // we have now modeled all effects and there is no need for any update.
    indicateOptimisticFixpoint();
  }
  ChangeStatus updateImpl(Attributor &A) override {
    // TODO: Once we have call site specific value information we can provide
    //       call site specific liveness information and then it makes
    //       sense to specialize attributes for call site arguments instead of
    //       redirecting requests to the callee argument.
    Function *F = getAssociatedFunction();

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);

    // If F is not a runtime function, propagate the AAKernelInfo of the
    // callee.
    if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
      const IRPosition &FnPos = IRPosition::function(*F);
      auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
      if (getState() == FnAA.getState())
        return ChangeStatus::UNCHANGED;
      getState() = FnAA.getState();
      return ChangeStatus::CHANGED;
    }

    // F is a runtime function that allocates or frees memory, check
    // AAHeapToStack and AAHeapToShared.
    KernelInfoState StateBefore = getState();
    assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
            It->getSecond() == OMPRTL___kmpc_free_shared) &&
           "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");

    CallBase &CB = cast<CallBase>(getAssociatedValue());

    auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
        *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
    auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
        *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);

    RuntimeFunction RF = It->getSecond();

    switch (RF) {
    // If neither HeapToStack nor HeapToShared assume the call is removed,
    // assume SPMD incompatibility.
    case OMPRTL___kmpc_alloc_shared:
      if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
          !HeapToSharedAA.isAssumedHeapToShared(CB))
        SPMDCompatibilityTracker.insert(&CB);
      break;
    case OMPRTL___kmpc_free_shared:
      if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
          !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
        SPMDCompatibilityTracker.insert(&CB);
      break;
    default:
      SPMDCompatibilityTracker.insert(&CB);
      break;
    }

    return StateBefore == getState() ? ChangeStatus::UNCHANGED
                                     : ChangeStatus::CHANGED;
  }
};
struct AAFoldRuntimeCall
    : public StateWrapper<BooleanState, AbstractAttribute> {
  using Base = StateWrapper<BooleanState, AbstractAttribute>;

  AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}

  /// Statistics are tracked as part of manifest for now.
  void trackStatistics() const override {}

  /// Create an abstract attribute view for the position \p IRP.
  static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
                                              Attributor &A);

  /// See AbstractAttribute::getName()
  const std::string getName() const override { return "AAFoldRuntimeCall"; }

  /// See AbstractAttribute::getIdAddr()
  const char *getIdAddr() const override { return &ID; }

  /// This function should return true if the type of the \p AA is
  /// AAFoldRuntimeCall
  static bool classof(const AbstractAttribute *AA) {
    return (AA->getIdAddr() == &ID);
  }

  /// Unique ID (due to the unique address)
  static const char ID;
};
struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
  AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
      : AAFoldRuntimeCall(IRP, A) {}

  /// See AbstractAttribute::getAsStr()
  const std::string getAsStr() const override {
    if (!isValidState())
      return "<invalid>";

    std::string Str("simplified value: ");

    if (!SimplifiedValue.hasValue())
      return Str + std::string("none");

    if (!SimplifiedValue.getValue())
      return Str + std::string("nullptr");

    if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.getValue()))
      return Str + std::to_string(CI->getSExtValue());

    return Str + std::string("unknown");
  }

  void initialize(Attributor &A) override {
    if (DisableOpenMPOptFolding)
      indicatePessimisticFixpoint();

    Function *Callee = getAssociatedFunction();

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
    assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
           "Expected a known OpenMP runtime function");

    RFKind = It->getSecond();

    CallBase &CB = cast<CallBase>(getAssociatedValue());
    A.registerSimplificationCallback(
        IRPosition::callsite_returned(CB),
        [&](const IRPosition &IRP, const AbstractAttribute *AA,
            bool &UsedAssumedInformation) -> Optional<Value *> {
          assert((isValidState() || (SimplifiedValue.hasValue() &&
                                     SimplifiedValue.getValue() == nullptr)) &&
                 "Unexpected invalid state!");

          if (!isAtFixpoint()) {
            UsedAssumedInformation = true;
            if (AA)
              A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
          }
          return SimplifiedValue;
        });
  }
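  // The simplification callback registered in initialize() is how the rest of
  // the Attributor observes the folded value: once this AA is at a fixpoint,
  // queries for the call's simplified value return SimplifiedValue; before
  // that, the query is marked as relying on assumed information. updateImpl
  // below merely dispatches to one of the fold* helpers based on RFKind.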
  ChangeStatus updateImpl(Attributor &A) override {
    ChangeStatus Changed = ChangeStatus::UNCHANGED;
    switch (RFKind) {
    case OMPRTL___kmpc_is_spmd_exec_mode:
      Changed |= foldIsSPMDExecMode(A);
      break;
    case OMPRTL___kmpc_is_generic_main_thread_id:
      Changed |= foldIsGenericMainThread(A);
      break;
    case OMPRTL___kmpc_parallel_level:
      Changed |= foldParallelLevel(A);
      break;
    case OMPRTL___kmpc_get_hardware_num_threads_in_block:
      Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
      break;
    case OMPRTL___kmpc_get_hardware_num_blocks:
      Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
      break;
    default:
      llvm_unreachable("Unhandled OpenMP runtime function!");
    }

    return Changed;
  }

  ChangeStatus manifest(Attributor &A) override {
    ChangeStatus Changed = ChangeStatus::UNCHANGED;

    if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) {
      Instruction &CB = *getCtxI();
      A.changeValueAfterManifest(CB, **SimplifiedValue);
      A.deleteAfterManifest(CB);

      LLVM_DEBUG(dbgs() << TAG << "Folding runtime call: " << CB << " with "
                        << **SimplifiedValue << "\n");

      Changed = ChangeStatus::CHANGED;
    }

    return Changed;
  }

  ChangeStatus indicatePessimisticFixpoint() override {
    SimplifiedValue = nullptr;
    return AAFoldRuntimeCall::indicatePessimisticFixpoint();
  }
private:
  /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
  ChangeStatus foldIsSPMDExecMode(Attributor &A) {
    Optional<Value *> SimplifiedValueBefore = SimplifiedValue;

    unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
    unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
    auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);

    if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
      return indicatePessimisticFixpoint();

    for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
      auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
                                          DepClassTy::REQUIRED);

      if (!AA.isValidState()) {
        SimplifiedValue = nullptr;
        return indicatePessimisticFixpoint();
      }

      if (AA.SPMDCompatibilityTracker.isAssumed()) {
        if (AA.SPMDCompatibilityTracker.isAtFixpoint())
          ++KnownSPMDCount;
        else
          ++AssumedSPMDCount;
      } else {
        if (AA.SPMDCompatibilityTracker.isAtFixpoint())
          ++KnownNonSPMDCount;
        else
          ++AssumedNonSPMDCount;
      }
    }

    if ((AssumedSPMDCount + KnownSPMDCount) &&
        (AssumedNonSPMDCount + KnownNonSPMDCount))
      return indicatePessimisticFixpoint();

    auto &Ctx = getAnchorValue().getContext();
    if (KnownSPMDCount || AssumedSPMDCount) {
      assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
             "Expected only SPMD kernels!");
      // All reaching kernels are in SPMD mode. Update all function calls to
      // __kmpc_is_spmd_exec_mode to 1.
      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
    } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
      assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
             "Expected only non-SPMD kernels!");
      // All reaching kernels are in non-SPMD mode. Update all function
      // calls to __kmpc_is_spmd_exec_mode to 0.
      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
    } else {
      // We have empty reaching kernels, therefore we cannot tell if the
      // associated call site can be folded. At this moment, SimplifiedValue
      // must be none.
      assert(!SimplifiedValue.hasValue() && "SimplifiedValue should be none");
    }

    return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
                                                    : ChangeStatus::CHANGED;
  }
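  // Example: if every kernel reaching a __kmpc_is_spmd_exec_mode call is (or
  // is assumed to be) SPMD, the call folds to an i8 1 above; manifest() then
  // rewrites the call's uses and erases it. Mixed SPMD/non-SPMD reachability
  // keeps the call as-is via the pessimistic fixpoint.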
  /// Fold __kmpc_is_generic_main_thread_id into a constant if possible.
  ChangeStatus foldIsGenericMainThread(Attributor &A) {
    Optional<Value *> SimplifiedValueBefore = SimplifiedValue;

    CallBase &CB = cast<CallBase>(getAssociatedValue());
    Function *F = CB.getFunction();
    const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
        *this, IRPosition::function(*F), DepClassTy::REQUIRED);

    if (!ExecutionDomainAA.isValidState())
      return indicatePessimisticFixpoint();

    auto &Ctx = getAnchorValue().getContext();
    if (ExecutionDomainAA.isExecutedByInitialThreadOnly(CB))
      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
    else
      return indicatePessimisticFixpoint();

    return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
                                                    : ChangeStatus::CHANGED;
  }
  /// Fold __kmpc_parallel_level into a constant if possible.
  ChangeStatus foldParallelLevel(Attributor &A) {
    Optional<Value *> SimplifiedValueBefore = SimplifiedValue;

    auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);

    if (!CallerKernelInfoAA.ParallelLevels.isValidState())
      return indicatePessimisticFixpoint();

    if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
      return indicatePessimisticFixpoint();

    if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
      assert(!SimplifiedValue.hasValue() &&
             "SimplifiedValue should keep none at this point");
      return ChangeStatus::UNCHANGED;
    }

    unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
    unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
    for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
      auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
                                          DepClassTy::REQUIRED);
      if (!AA.SPMDCompatibilityTracker.isValidState())
        return indicatePessimisticFixpoint();

      if (AA.SPMDCompatibilityTracker.isAssumed()) {
        if (AA.SPMDCompatibilityTracker.isAtFixpoint())
          ++KnownSPMDCount;
        else
          ++AssumedSPMDCount;
      } else {
        if (AA.SPMDCompatibilityTracker.isAtFixpoint())
          ++KnownNonSPMDCount;
        else
          ++AssumedNonSPMDCount;
      }
    }

    if ((AssumedSPMDCount + KnownSPMDCount) &&
        (AssumedNonSPMDCount + KnownNonSPMDCount))
      return indicatePessimisticFixpoint();

    auto &Ctx = getAnchorValue().getContext();
    // If the caller can only be reached by SPMD kernel entries, the parallel
    // level is 1. Similarly, if the caller can only be reached by non-SPMD
    // kernel entries, it is 0.
    if (AssumedSPMDCount || KnownSPMDCount) {
      assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
             "Expected only SPMD kernels!");
      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
    } else {
      assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
             "Expected only non-SPMD kernels!");
      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
    }

    return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
                                                    : ChangeStatus::CHANGED;
  }
  ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
    // Specialize only if all the calls agree with the attribute constant
    // value.
    int32_t CurrentAttrValue = -1;
    Optional<Value *> SimplifiedValueBefore = SimplifiedValue;

    auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);

    if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
      return indicatePessimisticFixpoint();

    // Iterate over the kernels that reach this function.
    for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
      int32_t NextAttrVal = -1;
      if (K->hasFnAttribute(Attr))
        NextAttrVal =
            std::stoi(K->getFnAttribute(Attr).getValueAsString().str());

      if (NextAttrVal == -1 ||
          (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
        return indicatePessimisticFixpoint();
      CurrentAttrValue = NextAttrVal;
    }

    if (CurrentAttrValue != -1) {
      auto &Ctx = getAnchorValue().getContext();
      SimplifiedValue =
          ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
    }
    return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
                                                    : ChangeStatus::CHANGED;
  }

  /// An optional value the associated value is assumed to fold to. That is, we
  /// assume the associated value (which is a call) can be replaced by this
  /// simplified value.
  Optional<Value *> SimplifiedValue;

  /// The runtime function kind of the callee of the associated call site.
  RuntimeFunction RFKind;
};

} // namespace
/// Register folding callsite
void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
  auto &RFI = OMPInfoCache.RFIs[RF];
  RFI.foreachUse(SCC, [&](Use &U, Function &F) {
    CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
    if (!CI)
      return false;
    A.getOrCreateAAFor<AAFoldRuntimeCall>(
        IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
        DepClassTy::NONE, /* ForceUpdate */ false,
        /* UpdateAfterInit */ false);
    return false;
  });
}
void OpenMPOpt::registerAAs(bool IsModulePass) {
  if (SCC.empty())
    return;
  if (IsModulePass) {
    // Ensure we create the AAKernelInfo AAs first and without triggering an
    // update. This will make sure we register all value simplification
    // callbacks before any other AA has the chance to create an AAValueSimplify
    // or similar.
    for (Function *Kernel : OMPInfoCache.Kernels)
      A.getOrCreateAAFor<AAKernelInfo>(
          IRPosition::function(*Kernel), /* QueryingAA */ nullptr,
          DepClassTy::NONE, /* ForceUpdate */ false,
          /* UpdateAfterInit */ false);

    registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id);
    registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
    registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
    registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
    registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
  }

  // Create CallSite AA for all Getters.
  for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
    auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];

    auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];

    auto CreateAA = [&](Use &U, Function &Caller) {
      CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
      if (!CI)
        return false;

      auto &CB = cast<CallBase>(*CI);

      IRPosition CBPos = IRPosition::callsite_function(CB);
      A.getOrCreateAAFor<AAICVTracker>(CBPos);
      return false;
    };

    GetterRFI.foreachUse(SCC, CreateAA);
  }
  auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
  auto CreateAA = [&](Use &U, Function &F) {
    A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
    return false;
  };
  if (!DisableOpenMPOptDeglobalization)
    GlobalizationRFI.foreachUse(SCC, CreateAA);

  // Create an ExecutionDomain AA for every function and a HeapToStack AA for
  // every function if there is a device kernel.
  if (!isOpenMPDevice(M))
    return;

  for (auto *F : SCC) {
    if (F->isDeclaration())
      continue;

    A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
    if (!DisableOpenMPOptDeglobalization)
      A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));

    for (auto &I : instructions(*F)) {
      if (auto *LI = dyn_cast<LoadInst>(&I)) {
        bool UsedAssumedInformation = false;
        A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
                               UsedAssumedInformation);
      }
    }
  }
}
const char AAICVTracker::ID = 0;
const char AAKernelInfo::ID = 0;
const char AAExecutionDomain::ID = 0;
const char AAHeapToShared::ID = 0;
const char AAFoldRuntimeCall::ID = 0;
AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
                                              Attributor &A) {
  AAICVTracker *AA = nullptr;
  switch (IRP.getPositionKind()) {
  case IRPosition::IRP_INVALID:
  case IRPosition::IRP_FLOAT:
  case IRPosition::IRP_ARGUMENT:
  case IRPosition::IRP_CALL_SITE_ARGUMENT:
    llvm_unreachable("ICVTracker can only be created for function position!");
  case IRPosition::IRP_RETURNED:
    AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
    break;
  case IRPosition::IRP_CALL_SITE_RETURNED:
    AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
    break;
  case IRPosition::IRP_CALL_SITE:
    AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
    break;
  case IRPosition::IRP_FUNCTION:
    AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
    break;
  }

  return *AA;
}
AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
                                                        Attributor &A) {
  AAExecutionDomainFunction *AA = nullptr;
  switch (IRP.getPositionKind()) {
  case IRPosition::IRP_INVALID:
  case IRPosition::IRP_FLOAT:
  case IRPosition::IRP_ARGUMENT:
  case IRPosition::IRP_CALL_SITE_ARGUMENT:
  case IRPosition::IRP_RETURNED:
  case IRPosition::IRP_CALL_SITE_RETURNED:
  case IRPosition::IRP_CALL_SITE:
    llvm_unreachable(
        "AAExecutionDomain can only be created for function position!");
  case IRPosition::IRP_FUNCTION:
    AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
    break;
  }

  return *AA;
}
AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
                                                  Attributor &A) {
  AAHeapToSharedFunction *AA = nullptr;
  switch (IRP.getPositionKind()) {
  case IRPosition::IRP_INVALID:
  case IRPosition::IRP_FLOAT:
  case IRPosition::IRP_ARGUMENT:
  case IRPosition::IRP_CALL_SITE_ARGUMENT:
  case IRPosition::IRP_RETURNED:
  case IRPosition::IRP_CALL_SITE_RETURNED:
  case IRPosition::IRP_CALL_SITE:
    llvm_unreachable(
        "AAHeapToShared can only be created for function position!");
  case IRPosition::IRP_FUNCTION:
    AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
    break;
  }

  return *AA;
}
AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
                                              Attributor &A) {
  AAKernelInfo *AA = nullptr;
  switch (IRP.getPositionKind()) {
  case IRPosition::IRP_INVALID:
  case IRPosition::IRP_FLOAT:
  case IRPosition::IRP_ARGUMENT:
  case IRPosition::IRP_RETURNED:
  case IRPosition::IRP_CALL_SITE_RETURNED:
  case IRPosition::IRP_CALL_SITE_ARGUMENT:
    llvm_unreachable("KernelInfo can only be created for function position!");
  case IRPosition::IRP_CALL_SITE:
    AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
    break;
  case IRPosition::IRP_FUNCTION:
    AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
    break;
  }

  return *AA;
}
AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
                                                        Attributor &A) {
  AAFoldRuntimeCall *AA = nullptr;
  switch (IRP.getPositionKind()) {
  case IRPosition::IRP_INVALID:
  case IRPosition::IRP_FLOAT:
  case IRPosition::IRP_ARGUMENT:
  case IRPosition::IRP_RETURNED:
  case IRPosition::IRP_FUNCTION:
  case IRPosition::IRP_CALL_SITE:
  case IRPosition::IRP_CALL_SITE_ARGUMENT:
    llvm_unreachable(
        "AAFoldRuntimeCall can only be created for call site position!");
  case IRPosition::IRP_CALL_SITE_RETURNED:
    AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
    break;
  }

  return *AA;
}
PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
  if (!containsOpenMP(M))
    return PreservedAnalyses::all();
  if (DisableOpenMPOptimizations)
    return PreservedAnalyses::all();

  FunctionAnalysisManager &FAM =
      AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  KernelSet Kernels = getDeviceKernels(M);

  auto IsCalled = [&](Function &F) {
    if (Kernels.contains(&F))
      return true;
    for (const User *U : F.users())
      if (!isa<BlockAddress>(U))
        return true;
    return false;
  };

  auto EmitRemark = [&](Function &F) {
    auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    ORE.emit([&]() {
      OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
      return ORA << "Could not internalize function. "
                 << "Some optimizations may not be possible. [OMP140]";
    });
  };

  // Create internal copies of each function if this is a kernel Module. This
  // allows interprocedural passes to see every call edge.
  DenseMap<Function *, Function *> InternalizedMap;
  if (isOpenMPDevice(M)) {
    SmallPtrSet<Function *, 16> InternalizeFns;
    for (Function &F : M)
      if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
          !DisableInternalization) {
        if (Attributor::isInternalizable(F)) {
          InternalizeFns.insert(&F);
        } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
          EmitRemark(F);
        }
      }

    Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
  }

  // Look at every function in the Module unless it was internalized.
  SmallVector<Function *, 16> SCC;
  for (Function &F : M)
    if (!F.isDeclaration() && !InternalizedMap.lookup(&F))
      SCC.push_back(&F);

  if (SCC.empty())
    return PreservedAnalyses::all();

  AnalysisGetter AG(FAM);

  auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
    return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
  };

  BumpPtrAllocator Allocator;
  CallGraphUpdater CGUpdater;

  SetVector<Function *> Functions(SCC.begin(), SCC.end());
  OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels);

  unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
  Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false,
               MaxFixpointIterations, OREGetter, DEBUG_TYPE);

  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
  bool Changed = OMPOpt.run(true);

  if (PrintModuleAfterOptimizations)
    LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);

  if (Changed)
    return PreservedAnalyses::none();

  return PreservedAnalyses::all();
}
PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
                                          CGSCCAnalysisManager &AM,
                                          LazyCallGraph &CG,
                                          CGSCCUpdateResult &UR) {
  if (!containsOpenMP(*C.begin()->getFunction().getParent()))
    return PreservedAnalyses::all();
  if (DisableOpenMPOptimizations)
    return PreservedAnalyses::all();

  SmallVector<Function *, 16> SCC;
  // If there are kernels in the module, we have to run on all SCC's.
  for (LazyCallGraph::Node &N : C) {
    Function *Fn = &N.getFunction();
    SCC.push_back(Fn);
  }

  if (SCC.empty())
    return PreservedAnalyses::all();

  Module &M = *C.begin()->getFunction().getParent();

  KernelSet Kernels = getDeviceKernels(M);

  FunctionAnalysisManager &FAM =
      AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();

  AnalysisGetter AG(FAM);

  auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
    return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
  };

  BumpPtrAllocator Allocator;
  CallGraphUpdater CGUpdater;
  CGUpdater.initialize(CG, C, AM, UR);

  SetVector<Function *> Functions(SCC.begin(), SCC.end());
  OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
                                /*CGSCC*/ Functions, Kernels);

  unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
  Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
               MaxFixpointIterations, OREGetter, DEBUG_TYPE);

  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
  bool Changed = OMPOpt.run(false);

  if (PrintModuleAfterOptimizations)
    LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);

  if (Changed)
    return PreservedAnalyses::none();

  return PreservedAnalyses::all();
}
namespace {

struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
  CallGraphUpdater CGUpdater;
  static char ID;

  OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
    initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    CallGraphSCCPass::getAnalysisUsage(AU);
  }

  bool runOnSCC(CallGraphSCC &CGSCC) override {
    if (!containsOpenMP(CGSCC.getCallGraph().getModule()))
      return false;
    if (DisableOpenMPOptimizations || skipSCC(CGSCC))
      return false;

    SmallVector<Function *, 16> SCC;
    // If there are kernels in the module, we have to run on all SCC's.
    for (CallGraphNode *CGN : CGSCC) {
      Function *Fn = CGN->getFunction();
      if (!Fn || Fn->isDeclaration())
        continue;
      SCC.push_back(Fn);
    }

    if (SCC.empty())
      return false;

    Module &M = CGSCC.getCallGraph().getModule();
    KernelSet Kernels = getDeviceKernels(M);

    CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
    CGUpdater.initialize(CG, CGSCC);

    // Maintain a map of functions to avoid rebuilding the ORE
    DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
    auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
      std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
      if (!ORE)
        ORE = std::make_unique<OptimizationRemarkEmitter>(F);
      return *ORE;
    };

    AnalysisGetter AG;
    SetVector<Function *> Functions(SCC.begin(), SCC.end());
    BumpPtrAllocator Allocator;
    OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
                                  Allocator,
                                  /*CGSCC*/ Functions, Kernels);

    unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
    Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
                 MaxFixpointIterations, OREGetter, DEBUG_TYPE);

    OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
    bool Result = OMPOpt.run(false);

    if (PrintModuleAfterOptimizations)
      LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);

    return Result;
  }

  bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
};

} // end anonymous namespace
KernelSet llvm::omp::getDeviceKernels(Module &M) {
  // TODO: Create a more cross-platform way of determining device kernels.
  NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
  KernelSet Kernels;

  if (!MD)
    return Kernels;

  for (auto *Op : MD->operands()) {
    if (Op->getNumOperands() < 2)
      continue;
    MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
    if (!KindID || KindID->getString() != "kernel")
      continue;

    Function *KernelFn =
        mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
    if (!KernelFn)
      continue;

    ++NumOpenMPTargetRegionKernels;

    Kernels.insert(KernelFn);
  }

  return Kernels;
}
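// For reference, the "nvvm.annotations" entries walked above typically look
// like (illustrative):
//   !nvvm.annotations = !{!0}
//   !0 = !{void ()* @kernel_fn, !"kernel", i32 1}
// i.e. operand 0 is the kernel function and operand 1 is the "kernel" string
// checked above.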
bool llvm::omp::containsOpenMP(Module &M) {
  Metadata *MD = M.getModuleFlag("openmp");
  if (!MD)
    return false;

  return true;
}

bool llvm::omp::isOpenMPDevice(Module &M) {
  Metadata *MD = M.getModuleFlag("openmp-device");
  if (!MD)
    return false;

  return true;
}
char OpenMPOptCGSCCLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
                      "OpenMP specific optimizations", false, false)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
                    "OpenMP specific optimizations", false, false)

Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
  return new OpenMPOptCGSCCLegacyPass();
}