//===-- AMDGPULowerKernelAttributes.cpp -----------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
/// \file This pass attempts to make use of reqd_work_group_size metadata
/// to eliminate loads from the dispatch packet and to constant fold OpenCL
/// get_local_size-like functions.
13 //===----------------------------------------------------------------------===//
16 #include "AMDGPUTargetMachine.h"
17 #include "llvm/Analysis/ValueTracking.h"
18 #include "llvm/CodeGen/Passes.h"
19 #include "llvm/CodeGen/TargetPassConfig.h"
20 #include "llvm/IR/Constants.h"
21 #include "llvm/IR/Function.h"
22 #include "llvm/IR/Instructions.h"
23 #include "llvm/IR/PatternMatch.h"
24 #include "llvm/Pass.h"
26 #define DEBUG_TYPE "amdgpu-lower-kernel-attributes"
32 // Byte offsets of fields in hsa_kernel_dispatch_packet_t. processUse
// matches loads at these constant offsets from the dispatch pointer.
// NOTE(review): the enumerator list is missing from this copy; only the
// WORKGROUP_SIZE_{X,Y,Z} cases are visible in processUse (GRID_SIZE_* cases
// are presumably in elided lines) — confirm against the full file.
33 enum DispatchPackedOffsets
{
// Module pass that uses reqd_work_group_size metadata and the
// "uniform-work-group-size" attribute of kernels to fold work-group-size
// queries made through the dispatch packet (see processUse).
// NOTE(review): `static char ID;` (referenced by the constructor and defined
// after INITIALIZE_PASS_END) is not visible in this copy.
43 class AMDGPULowerKernelAttributes
: public ModulePass
{
// Module being processed; read by processUse (DataLayout) and runOnModule
// (intrinsic lookup). Presumably set in doInitialization, whose body is not
// visible here.
44 Module
*Mod
= nullptr;
49 AMDGPULowerKernelAttributes() : ModulePass(ID
) {}
// Rewrites the dispatch-packet uses reachable from one dispatch-pointer
// call; presumably returns true when the IR was changed.
51 bool processUse(CallInst
*CI
);
53 bool doInitialization(Module
&M
) override
;
54 bool runOnModule(Module
&M
) override
;
56 StringRef
getPassName() const override
{
57 return "AMDGPU Kernel Attributes";
60 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
// NOTE(review): the getAnalysisUsage body, the closing braces of this class
// and the matching `namespace {` opener are not visible in this copy.
65 } // end anonymous namespace
// Per-module setup run before runOnModule. The rest of the pass reads the
// module through the `Mod` member, so this is presumably where it is cached.
// NOTE(review): the body is not visible in this copy — confirm it stores &M
// into Mod and returns false (no IR change).
67 bool AMDGPULowerKernelAttributes::doInitialization(Module
&M
) {
// Folds work-group-size queries that read through one dispatch-pointer call.
//
// CI is expected to be a call to the llvm.amdgcn.dispatch.ptr intrinsic
// (runOnModule iterates over that intrinsic's call sites; the processUse call
// itself is elided in this copy). F is the function containing CI.
//
// Steps visible below:
//  1. Bail out unless F has a three-operand !reqd_work_group_size node or the
//     "uniform-work-group-size"="true" function attribute.
//  2. Record simple loads of dispatch-packet fields, keyed by their constant
//     byte offset from CI.
//  3. With a uniform work-group size, replace the partial-group clamp select
//     from get_local_size with the group size (a constant when the required
//     size is known).
//  4. With !reqd_work_group_size, replace the remaining work-group-size loads
//     with the metadata constants.
//
// NOTE(review): several lines of this function are missing from this copy:
// the return statements, the `continue`s, the `switch (Offset)` header, the
// size checks/`break`s and the GRID_SIZE_* cases. The result is presumably
// MadeChange, i.e. whether any IR was rewritten.
72 bool AMDGPULowerKernelAttributes::processUse(CallInst
*CI
) {
73 Function
*F
= CI
->getParent()->getParent();
// Only a node with exactly three operands (X, Y, Z) is usable; operand I is
// read below for dimension I.
75 auto MD
= F
->getMetadata("reqd_work_group_size");
76 const bool HasReqdWorkGroupSize
= MD
&& MD
->getNumOperands() == 3;
// String-valued function attribute; only the exact value "true" counts.
78 const bool HasUniformWorkGroupSize
=
79 F
->getFnAttribute("uniform-work-group-size").getValueAsString() == "true";
// Neither guarantee holds, so nothing can be folded (the early return is
// elided in this copy).
81 if (!HasReqdWorkGroupSize
&& !HasUniformWorkGroupSize
)
// Loads of the per-dimension work-group-size and grid-size fields of the
// dispatch packet; each stays nullptr when no matching load is found.
84 Value
*WorkGroupSizeX
= nullptr;
85 Value
*WorkGroupSizeY
= nullptr;
86 Value
*WorkGroupSizeZ
= nullptr;
88 Value
*GridSizeX
= nullptr;
89 Value
*GridSizeY
= nullptr;
90 Value
*GridSizeZ
= nullptr;
92 const DataLayout
&DL
= Mod
->getDataLayout();
94 // We expect to see several GEP users, cast to the appropriate type and
// loaded: each chain is matched as user of CI at a constant byte offset ->
// single bitcast -> single simple load.
96 for (User
*U
: CI
->users()) {
// NOTE(review): `Offset` is declared in lines elided from this copy. The
// `*U->user_begin()` dereference below needs U to have at least one user;
// confirm the elided guard (presumably U->hasOneUse()) is still present.
101 if (GetPointerBaseWithConstantOffset(U
, Offset
, DL
) != CI
)
// NOTE(review): with opaque pointers there may be no bitcast between the
// address and the load, in which case this chain never matches — confirm
// against the LLVM version this file targets.
104 auto *BCI
= dyn_cast
<BitCastInst
>(*U
->user_begin());
105 if (!BCI
|| !BCI
->hasOneUse())
// Only non-volatile, non-atomic loads are safe to replace.
108 auto *Load
= dyn_cast
<LoadInst
>(*BCI
->user_begin());
109 if (!Load
|| !Load
->isSimple())
// Width in bytes of the loaded value, presumably checked against each
// field's size in the elided part of the switch.
112 unsigned LoadSize
= DL
.getTypeStoreSize(Load
->getType());
114 // TODO: Handle merged loads.
// Dispatch on the field's byte offset (switch header elided in this copy).
116 case WORKGROUP_SIZE_X
:
118 WorkGroupSizeX
= Load
;
120 case WORKGROUP_SIZE_Y
:
122 WorkGroupSizeY
= Load
;
124 case WORKGROUP_SIZE_Z
:
126 WorkGroupSizeZ
= Load
;
145 // Pattern match the code used to handle partial workgroup dispatches in the
146 // library implementation of get_local_size, so the entire function can be
147 // constant folded with a known group size.
//
149 // uint r = grid_size - group_id * group_size;
150 // get_local_size = (r < group_size) ? r : group_size;
//
152 // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
153 // the grid_size is required to be a multiple of group_size. In this case:
//
155 // grid_size - (group_id * group_size) < group_size
// =>
157 // grid_size < group_size + (group_id * group_size)
// =>
159 // (grid_size / group_size) < 1 + group_id
//
161 // grid_size / group_size is at least 1, so we can conclude the select
162 // condition is false (except for group_id == 0, where the select result is
// the same). Either way the select can be replaced by group_size.
165 bool MadeChange
= false;
// Per-dimension views of the loads collected above (index 0..2 = X, Y, Z).
166 Value
*WorkGroupSizes
[3] = { WorkGroupSizeX
, WorkGroupSizeY
, WorkGroupSizeZ
};
167 Value
*GridSizes
[3] = { GridSizeX
, GridSizeY
, GridSizeZ
};
// The select fold is only valid with a uniform work-group size (see the
// derivation above), so the loop does not run otherwise.
169 for (int I
= 0; HasUniformWorkGroupSize
&& I
< 3; ++I
) {
170 Value
*GroupSize
= WorkGroupSizes
[I
];
171 Value
*GridSize
= GridSizes
[I
];
// Both loads are needed to recognise the pattern (`continue` elided).
172 if (!GroupSize
|| !GridSize
)
// Look for zext(group_size) feeding a select that implements the
// partial-group clamp.
// NOTE(review): the null checks after the dyn_casts of ZextGroupSize and SI
// are elided in this copy.
175 for (User
*U
: GroupSize
->users()) {
176 auto *ZextGroupSize
= dyn_cast
<ZExtInst
>(U
);
180 for (User
*ZextUser
: ZextGroupSize
->users()) {
181 auto *SI
= dyn_cast
<SelectInst
>(ZextUser
);
185 using namespace llvm::PatternMatch
;
// Workgroup-id intrinsic for dimension I (0 = x, 1 = y, 2 = z).
186 auto GroupIDIntrin
= I
== 0 ?
187 m_Intrinsic
<Intrinsic::amdgcn_workgroup_id_x
>() :
188 (I
== 1 ? m_Intrinsic
<Intrinsic::amdgcn_workgroup_id_y
>() :
189 m_Intrinsic
<Intrinsic::amdgcn_workgroup_id_z
>());
// r = grid_size - group_id * zext(group_size)
191 auto SubExpr
= m_Sub(m_Specific(GridSize
),
192 m_Mul(GroupIDIntrin
, m_Specific(ZextGroupSize
)));
// Matches select(icmp ult r, zext(group_size), ..., zext(group_size)).
// NOTE(review): the `if (match(SI,` opener and the select's true-value
// operand (presumably SubExpr) are elided in this copy.
194 ICmpInst::Predicate Pred
;
196 m_Select(m_ICmp(Pred
, SubExpr
, m_Specific(ZextGroupSize
)),
198 m_Specific(ZextGroupSize
))) &&
199 Pred
== ICmpInst::ICMP_ULT
) {
// With a known required size, fold the select to that constant
// (presumably cast to the select's type in the elided arguments);
// otherwise it always yields zext(group_size).
200 if (HasReqdWorkGroupSize
) {
201 ConstantInt
*KnownSize
202 = mdconst::extract
<ConstantInt
>(MD
->getOperand(I
));
203 SI
->replaceAllUsesWith(ConstantExpr::getIntegerCast(KnownSize
,
207 SI
->replaceAllUsesWith(ZextGroupSize
);
// Without reqd_work_group_size there is nothing more to fold; the return
// (presumably of MadeChange) is elided in this copy.
216 if (!HasReqdWorkGroupSize
)
219 // Eliminate any other loads we can from the dispatch packet.
// Each remaining work-group-size load is replaced by the metadata value for
// its dimension, integer-cast to the load's type.
// NOTE(review): the null check on GroupSize is elided in this copy.
220 for (int I
= 0; I
< 3; ++I
) {
221 Value
*GroupSize
= WorkGroupSizes
[I
];
225 ConstantInt
*KnownSize
= mdconst::extract
<ConstantInt
>(MD
->getOperand(I
));
226 GroupSize
->replaceAllUsesWith(
227 ConstantExpr::getIntegerCast(KnownSize
,
228 GroupSize
->getType(),
236 // TODO: Move makeLIDRangeMetadata usage into here. We do not seem to get the
237 // TargetPassConfig for the subtarget here.
//
// Visits each distinct user of the llvm.amdgcn.dispatch.ptr declaration
// and hands it to processUse.
// NOTE(review): the declaration is looked up through the cached `Mod` rather
// than the `M` parameter, so this relies on doInitialization having stored
// the same module.
238 bool AMDGPULowerKernelAttributes::runOnModule(Module
&M
) {
// Name under which the dispatch-pointer intrinsic is declared.
239 StringRef DispatchPtrName
240 = Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr
);
242 Function
*DispatchPtr
= Mod
->getFunction(DispatchPtrName
);
243 if (!DispatchPtr
) // Dispatch ptr not used.
// (the early return, presumably `return false;`, is elided in this copy)
246 bool MadeChange
= false;
// Ensures each call instruction is processed at most once.
248 SmallPtrSet
<Instruction
*, 4> HandledUses
;
249 for (auto *U
: DispatchPtr
->users()) {
// cast<> (not dyn_cast<>) asserts that every user is a call.
250 CallInst
*CI
= cast
<CallInst
>(U
);
251 if (HandledUses
.insert(CI
).second
) {
// NOTE(review): the processUse(CI) call, the MadeChange update and the
// function's return are not visible in this copy.
// Registers the pass under DEBUG_TYPE ("amdgpu-lower-kernel-attributes").
// NOTE(review): the description "AMDGPU IR optimizations" differs from
// getPassName() ("AMDGPU Kernel Attributes"); consider aligning the two.
// The trailing arguments of INITIALIZE_PASS_END are elided in this copy.
260 INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes
, DEBUG_TYPE
,
261 "AMDGPU IR optimizations", false, false)
262 INITIALIZE_PASS_END(AMDGPULowerKernelAttributes
, DEBUG_TYPE
, "AMDGPU IR optimizations",
// Definition of the pass ID passed to the ModulePass constructor.
265 char AMDGPULowerKernelAttributes::ID
= 0;
267 ModulePass
*llvm::createAMDGPULowerKernelAttributesPass() {
268 return new AMDGPULowerKernelAttributes();