1 //===-- AMDGPULowerKernelAttributes.cpp ------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 /// \file This pass attempts to make use of reqd_work_group_size metadata
10 /// to eliminate loads from the dispatch packet and to constant fold OpenCL
11 /// get_local_size-like functions.
13 //===----------------------------------------------------------------------===//
16 #include "llvm/Analysis/ValueTracking.h"
17 #include "llvm/CodeGen/Passes.h"
18 #include "llvm/CodeGen/TargetPassConfig.h"
19 #include "llvm/IR/Constants.h"
20 #include "llvm/IR/Function.h"
21 #include "llvm/IR/InstIterator.h"
22 #include "llvm/IR/Instructions.h"
23 #include "llvm/IR/IntrinsicsAMDGPU.h"
24 #include "llvm/IR/PatternMatch.h"
25 #include "llvm/Pass.h"
27 #define DEBUG_TYPE "amdgpu-lower-kernel-attributes"
33 // Field offsets in hsa_kernel_dispatch_packet_t.
// NOTE(review): processUse uses WORKGROUP_SIZE_X/Y/Z as switch case labels;
// presumably they are enumerators of this enum -- confirm against the full
// enumerator list.
34 enum DispatchPackedOffsets
{
/// Module pass (derives from ModulePass). Its runOnModule is only declared
/// here; the definition appears out of line later in the file.
44 class AMDGPULowerKernelAttributes
: public ModulePass
{
// Forwards the pass identifier ID to the ModulePass base class.
48 AMDGPULowerKernelAttributes() : ModulePass(ID
) {}
// Pass entry point over a whole module; returns a bool (presumably whether
// the IR was modified -- confirm against the definition).
50 bool runOnModule(Module
&M
) override
;
// Display name of the pass.
52 StringRef
getPassName() const override
{
53 return "AMDGPU Kernel Attributes";
// NOTE(review): confirm which analyses this override marks as required or
// preserved; that is not established by the lines shown here.
56 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
61 } // end anonymous namespace
/// Rewrites uses that hang off one call instruction \p CI, which is treated
/// below as the base pointer of constant-offset address computations
/// (presumably a call to the amdgcn_dispatch_ptr intrinsic -- confirm against
/// the callers).
///
/// Two facts about the enclosing function drive the rewrite:
///  - "reqd_work_group_size" metadata with exactly three operands, and
///  - the "uniform-work-group-size" function attribute.
///
/// With them, loads of the work-group size reached through \p CI and the
/// get_local_size-style select built on top of them are replaced by simpler
/// values (the metadata constants, or the zero-extended group-size load).
///
/// \returns a bool; presumably true when the IR was changed (see MadeChange)
/// -- TODO confirm, the return statements are not part of the lines below.
63 static bool processUse(CallInst
*CI
) {
// Function that contains the call.
64 Function
*F
= CI
->getParent()->getParent();
// The metadata is only considered usable when it has exactly three operands;
// operand I is later read as the known size for WorkGroupSizes[I].
66 auto MD
= F
->getMetadata("reqd_work_group_size");
67 const bool HasReqdWorkGroupSize
= MD
&& MD
->getNumOperands() == 3;
69 const bool HasUniformWorkGroupSize
=
70 F
->getFnAttribute("uniform-work-group-size").getValueAsBool();
// Case where neither piece of information is available.
72 if (!HasReqdWorkGroupSize
&& !HasUniformWorkGroupSize
)
// Per-dimension slots for the values discovered in the user scan below; they
// stay null until a matching load is recorded (the WorkGroupSize* slots are
// assigned the matching LoadInst in the case labels further down).
75 Value
*WorkGroupSizeX
= nullptr;
76 Value
*WorkGroupSizeY
= nullptr;
77 Value
*WorkGroupSizeZ
= nullptr;
79 Value
*GridSizeX
= nullptr;
80 Value
*GridSizeY
= nullptr;
81 Value
*GridSizeZ
= nullptr;
83 const DataLayout
&DL
= F
->getParent()->getDataLayout();
85 // We expect to see several GEP users, cast to the appropriate type and loaded.
87 for (User
*U
: CI
->users()) {
// Tests whether the user is something other than CI plus a constant offset
// (the offset is written to Offset).
92 if (GetPointerBaseWithConstantOffset(U
, Offset
, DL
) != CI
)
// Only the first user of U is inspected; it has to be a bitcast with exactly
// one use.
95 auto *BCI
= dyn_cast
<BitCastInst
>(*U
->user_begin());
96 if (!BCI
|| !BCI
->hasOneUse())
// The single user of that bitcast has to be a load for which
// LoadInst::isSimple() holds.
99 auto *Load
= dyn_cast
<LoadInst
>(*BCI
->user_begin());
100 if (!Load
|| !Load
->isSimple())
// Store size of the loaded type according to the data layout.
103 unsigned LoadSize
= DL
.getTypeStoreSize(Load
->getType());
105 // TODO: Handle merged loads.
// Record the load in the slot of the dimension named by the case label
// (presumably the switch is on Offset -- TODO confirm).
107 case WORKGROUP_SIZE_X
:
109 WorkGroupSizeX
= Load
;
111 case WORKGROUP_SIZE_Y
:
113 WorkGroupSizeY
= Load
;
115 case WORKGROUP_SIZE_Z
:
117 WorkGroupSizeZ
= Load
;
136 // Pattern match the code used to handle partial workgroup dispatches in the
137 // library implementation of get_local_size, so the entire function can be
138 // constant folded with a known group size.
140 // uint r = grid_size - group_id * group_size;
141 // get_local_size = (r < group_size) ? r : group_size;
143 // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
144 // the grid_size is required to be a multiple of group_size. In this case:
146 // grid_size - (group_id * group_size) < group_size
148 // grid_size < group_size + (group_id * group_size)
150 // (grid_size / group_size) < 1 + group_id
152 // grid_size / group_size is at least 1, so we can conclude the select
153 // condition is false (except for group_id == 0, where the select result is
// the same), so the select can be replaced by group_size.
156 bool MadeChange
= false;
// Index the per-dimension values so both loops below can iterate over I.
157 Value
*WorkGroupSizes
[3] = { WorkGroupSizeX
, WorkGroupSizeY
, WorkGroupSizeZ
};
158 Value
*GridSizes
[3] = { GridSizeX
, GridSizeY
, GridSizeZ
};
// The select folding is attempted only when uniform-work-group-size is set;
// otherwise the loop condition is false from the start.
160 for (int I
= 0; HasUniformWorkGroupSize
&& I
< 3; ++I
) {
161 Value
*GroupSize
= WorkGroupSizes
[I
];
162 Value
*GridSize
= GridSizes
[I
];
// Both values of this dimension are needed for the pattern.
163 if (!GroupSize
|| !GridSize
)
// Look through zero-extensions of the group size to selects that use them.
166 for (User
*U
: GroupSize
->users()) {
167 auto *ZextGroupSize
= dyn_cast
<ZExtInst
>(U
);
171 for (User
*ZextUser
: ZextGroupSize
->users()) {
172 auto *SI
= dyn_cast
<SelectInst
>(ZextUser
);
176 using namespace llvm::PatternMatch
;
// Matcher for the workgroup-id intrinsic of dimension I (x, y or z).
177 auto GroupIDIntrin
= I
== 0 ?
178 m_Intrinsic
<Intrinsic::amdgcn_workgroup_id_x
>() :
179 (I
== 1 ? m_Intrinsic
<Intrinsic::amdgcn_workgroup_id_y
>() :
180 m_Intrinsic
<Intrinsic::amdgcn_workgroup_id_z
>());
// Matches: grid_size - workgroup_id * zext(group_size), i.e. "r" in the
// comment above.
182 auto SubExpr
= m_Sub(m_Specific(GridSize
),
183 m_Mul(GroupIDIntrin
, m_Specific(ZextGroupSize
)));
// The select condition must compare SubExpr against zext(group_size) with an
// unsigned less-than predicate, and zext(group_size) must be the last select
// operand.
185 ICmpInst::Predicate Pred
;
187 m_Select(m_ICmp(Pred
, SubExpr
, m_Specific(ZextGroupSize
)),
189 m_Specific(ZextGroupSize
))) &&
190 Pred
== ICmpInst::ICMP_ULT
) {
// With reqd_work_group_size the select is replaced by the metadata constant
// (integer-cast); the other replacement below uses the zero-extended load
// instead (presumably the else branch -- TODO confirm).
191 if (HasReqdWorkGroupSize
) {
192 ConstantInt
*KnownSize
193 = mdconst::extract
<ConstantInt
>(MD
->getOperand(I
));
194 SI
->replaceAllUsesWith(ConstantExpr::getIntegerCast(KnownSize
,
198 SI
->replaceAllUsesWith(ZextGroupSize
);
// The remaining rewrite reads MD operands, so it is guarded on the metadata
// being usable.
207 if (!HasReqdWorkGroupSize
)
210 // Eliminate any other loads we can from the dispatch packet.
211 for (int I
= 0; I
< 3; ++I
) {
212 Value
*GroupSize
= WorkGroupSizes
[I
];
// Replace every use of the group-size load with the metadata constant,
// integer-cast to the type of the load.
216 ConstantInt
*KnownSize
= mdconst::extract
<ConstantInt
>(MD
->getOperand(I
));
217 GroupSize
->replaceAllUsesWith(
218 ConstantExpr::getIntegerCast(KnownSize
,
219 GroupSize
->getType(),
227 // TODO: Move makeLIDRangeMetadata usage into here. Seem to not get
228 // TargetPassConfig for subtarget.
//
// Module-level entry point. Looks up the function in M whose name is that of
// the amdgcn_dispatch_ptr intrinsic and walks all of its users, each of which
// is required to be a CallInst (cast<> is used, not dyn_cast<>).
// NOTE(review): what is done with each call, and the value returned, are
// presumably "processUse(CI)" and MadeChange -- confirm.
229 bool AMDGPULowerKernelAttributes::runOnModule(Module
&M
) {
230 StringRef DispatchPtrName
231 = Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr
);
// A missing function means the module never references the intrinsic.
233 Function
*DispatchPtr
= M
.getFunction(DispatchPtrName
);
234 if (!DispatchPtr
) // Dispatch ptr not used.
237 bool MadeChange
= false;
// Set of call instructions already seen, so that each call is handled once
// even if it shows up more than once in the user list.
239 SmallPtrSet
<Instruction
*, 4> HandledUses
;
240 for (auto *U
: DispatchPtr
->users()) {
241 CallInst
*CI
= cast
<CallInst
>(U
);
// insert(...).second is true only for a call not yet in the set.
242 if (HandledUses
.insert(CI
).second
) {
// Pass registration through the INITIALIZE_PASS_BEGIN/END macro pair, using
// DEBUG_TYPE ("amdgpu-lower-kernel-attributes") as the pass argument and
// "AMDGPU Kernel Attributes" as the description. No dependency entries
// appear between the two macros.
251 INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes
, DEBUG_TYPE
,
252 "AMDGPU Kernel Attributes", false, false)
253 INITIALIZE_PASS_END(AMDGPULowerKernelAttributes
, DEBUG_TYPE
,
254 "AMDGPU Kernel Attributes", false, false)
// Storage for the pass identifier that the constructor hands to ModulePass.
256 char AMDGPULowerKernelAttributes::ID
= 0;
/// Factory function: returns a newly heap-allocated
/// AMDGPULowerKernelAttributes as a ModulePass pointer. The object is created
/// with a raw `new`, so the caller is responsible for its lifetime
/// (presumably it is handed to a pass manager -- confirm at the call sites).
258 ModulePass
*llvm::createAMDGPULowerKernelAttributesPass() {
259 return new AMDGPULowerKernelAttributes();
/// Per-function entry point taking a FunctionAnalysisManager. Finds the
/// declaration of the amdgcn_dispatch_ptr intrinsic in the parent module and
/// scans every instruction of \p F for direct calls to it.
/// Both return statements shown yield PreservedAnalyses::all().
/// NOTE(review): the action taken for each matching call is presumably
/// processUse(CI) -- confirm; \p AM is not referenced in the lines shown.
263 AMDGPULowerKernelAttributesPass::run(Function
&F
, FunctionAnalysisManager
&AM
) {
264 StringRef DispatchPtrName
=
265 Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr
);
// No declaration in the module means no call can exist in F either.
267 Function
*DispatchPtr
= F
.getParent()->getFunction(DispatchPtrName
);
268 if (!DispatchPtr
) // Dispatch ptr not used.
269 return PreservedAnalyses::all();
271 for (Instruction
&I
: instructions(F
)) {
272 if (CallInst
*CI
= dyn_cast
<CallInst
>(&I
)) {
// Only calls whose callee is the dispatch-pointer intrinsic are of interest.
273 if (CI
->getCalledFunction() == DispatchPtr
)
278 return PreservedAnalyses::all();